diff options
author | Arnstein Ressem <aressem@gmail.com> | 2018-09-25 19:13:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-25 19:13:14 +0200 |
commit | dfe56b565f6728a6c537867d28e388cd04718071 (patch) | |
tree | 5e2727aaf1eef5ef5043b399508cfca99b2e33c2 /searchlib | |
parent | f627463a8100090ec109d27c3aeb439a3395a34f (diff) |
Revert "Bratseth/rank type information 2"
Diffstat (limited to 'searchlib')
11 files changed, 67 insertions, 153 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index f6502a9801d..da34ab8822d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -2,11 +2,8 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.log.event.Collection; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; -import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; import java.security.MessageDigest; @@ -16,14 +13,9 @@ import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Optional; /** - * A function defined by a ranking expression, optionally containing type information - * for inputs and outputs. - * - * Immutable, but note that ranking expressions are *not* immutable. + * A function defined by a ranking expression * * @author Simon Thoresen Hult * @author bratseth @@ -32,13 +24,8 @@ public class ExpressionFunction { private final String name; private final ImmutableList<String> arguments; - - /** Types of the inputs, if known. The keys here is any subset (including empty and identity) of the argument list */ - private final ImmutableMap<String, TensorType> argumentTypes; private final RankingExpression body; - private final Optional<TensorType> returnType; - /** * Constructs a new function with no arguments * @@ -57,18 +44,9 @@ public class ExpressionFunction { * @param body the ranking expression that defines this function */ public ExpressionFunction(String name, List<String> arguments, RankingExpression body) { - this(name, arguments, body, ImmutableMap.of(), Optional.empty()); - } - - public ExpressionFunction(String name, List<String> arguments, RankingExpression body, - Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { - this.name = Objects.requireNonNull(name, "name cannot be null"); + this.name = name; this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments); - this.body = Objects.requireNonNull(body, "body cannot be null"); - if ( ! this.arguments.containsAll(argumentTypes.keySet())) - throw new IllegalArgumentException("Argument type keys must be a subset of the argument keys"); - this.argumentTypes = ImmutableMap.copyOf(argumentTypes); - this.returnType = Objects.requireNonNull(returnType, "returnType cannot be null"); + this.body = body; } public String getName() { return name; } @@ -78,27 +56,9 @@ public class ExpressionFunction { public RankingExpression getBody() { return body; } - /** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */ - public Map<String, TensorType> argumentTypes() { return argumentTypes; } - - /** Returns the return type of this, or empty if not specified */ - public Optional<TensorType> returnType() { return returnType; } - - public ExpressionFunction withName(String name) { - return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); - } - /** Returns a copy of this with the body changed to the given value */ public ExpressionFunction withBody(RankingExpression body) { - return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); - } - - public ExpressionFunction withReturnType(TensorType returnType) { - return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); - } - - public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) { - return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + return new ExpressionFunction(name, arguments, body); } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 9ff391a5cfe..282a4c5e0a9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,22 +1,15 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.regex.Pattern; /** @@ -33,11 +26,12 @@ public class ImportedModel { private final String source; private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> inputs = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> smallConstants = new HashMap<>(); private final Map<String, Tensor> largeConstants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> functions = new HashMap<>(); + private final Map<String, TensorType> requiredFunctions = new HashMap<>(); /** * Creates a new imported model. @@ -55,11 +49,11 @@ public class ImportedModel { /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } - /** Returns the source path (directory or file) of this model */ + /** Returns the source path (directiry or file) of this model */ public String source() { return source; } - /** Returns an immutable map of the inputs of this */ - public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + /** Returns an immutable map of the arguments ("Placeholders") of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** * Returns an immutable map of the small constants of this. @@ -77,7 +71,7 @@ public class ImportedModel { /** * Returns an immutable map of the expressions of this - corresponding to graph nodes - * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants). + * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants). * Note that only nodes recursively referenced by a placeholder/input are added. */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } @@ -88,6 +82,9 @@ public class ImportedModel { */ public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + /** Returns an immutable map of the functions that must be provided by the environment running this model */ + public Map<String, TensorType> requiredFunctions() { return Collections.unmodifiableMap(requiredFunctions); } + /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -99,11 +96,12 @@ public class ImportedModel { /** Convenience method for returning a default signature */ Signature defaultSignature() { return signature(defaultSignatureName); } - void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void function(String name, RankingExpression expression) { functions.put(name, expression); } + void requiredFunction(String name, TensorType type) { requiredFunctions.put(name, type); } /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts @@ -111,39 +109,24 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + public List<Pair<String, RankingExpression>> outputExpressions() { + List<Pair<String, RankingExpression>> expressions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + expressions().get(outputEntry.getValue()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().keySet()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); + expressions().get(signatureEntry.getKey()))); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); + expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); + expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue())); } } } @@ -151,7 +134,7 @@ public class ImportedModel { } /** - * A signature is a set of named inputs and outputs, where the inputs maps to input + * A signature is a set of named inputs and outputs, where the inputs maps to argument * ("placeholder") names+types, and outputs maps to expressions nodes. * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit * concept of signatures. For now, we handle ONNX models as having a single signature. @@ -159,8 +142,8 @@ public class ImportedModel { public class Signature { private final String name; - private final Map<String, String> inputs = new LinkedHashMap<>(); - private final Map<String, String> outputs = new LinkedHashMap<>(); + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); @@ -175,20 +158,12 @@ public class ImportedModel { /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name - * in this signature to input name in the owning model + * to argument (Placeholder) name in the owner of this */ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - /** Returns the name and type of all inputs in this signature as an immutable map */ - public Map<String, TensorType> inputMap() { - ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); - for (Map.Entry<String, String> inputEntry : inputs().entrySet()) - inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue())); - return inputs.build(); - } - - /** Returns the type of the input this input references */ - public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); } + /** Returns the type of the argument this input references */ + public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } @@ -205,13 +180,7 @@ public class ImportedModel { public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references */ - public ExpressionFunction outputExpression(String outputName) { - return new ExpressionFunction(outputName, - new ArrayList<>(inputs.keySet()), - owner().expressions().get(outputs.get(outputName)), - inputMap(), - Optional.empty()); - } + public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } @Override public String toString() { return "signature '" + name + "'"; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index b7138ad87e3..d25502fd149 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -187,7 +187,8 @@ public abstract class ModelImporter { if (operation.isInput()) { // All inputs must have dimensions with standard naming convention: d0, d1, ... OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.input(operation.vespaName(), standardNamingConvention.type()); + model.argument(operation.vespaName(), standardNamingConvention.type()); + model.requiredFunction(operation.vespaName(), standardNamingConvention.type()); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index e6bb5f40b3f..917b0d6a389 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; import onnx.Onnx; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 94d663b4954..796c13a8669 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -3,8 +3,6 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; @@ -82,14 +80,9 @@ public class SerializationContext extends FunctionReferenceContext { serializedFunctions.put(name, expressionString); } - /** Adds the serialization of the an argument type to a function */ - public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { - serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString()); - } - - /** Adds the serialization of the return type of a function */ - public void addFunctionTypeSerialization(String functionName, TensorType type) { - serializedFunctions.put("rankingExpression(" + functionName + ").type", type.toString()); + /** Returns the existing serialization of a function, or null if none */ + public String getFunctionSerialization(String name) { + return serializedFunctions.get(name); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java index 969bc318391..118eba2cd96 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java @@ -15,7 +15,6 @@ import org.junit.Test; import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.charset.StandardCharsets; import java.util.Arrays; import static org.junit.Assert.fail; @@ -41,8 +40,7 @@ public class GroupingSerializationTest { t.assertMatch(new FloatResultNode(7.3)); t.assertMatch(new StringResultNode("7.3")); t.assertMatch(new StringResultNode( - new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c}, - StandardCharsets.UTF_8))); + new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c}))); t.assertMatch(new RawResultNode(new byte[]{'7', '.', '4'})); t.assertMatch(new IntegerBucketResultNode()); t.assertMatch(new FloatBucketResultNode()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 593e7b54c10..bf9684082f4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -21,11 +20,10 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); + model.assertEqualResult("X", output.getName()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 59712c0152f..a8f7542f3a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -20,23 +19,22 @@ public class DropoutImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); // Check required functions - assertEquals(1, model.get().inputs().size()); - assertTrue(model.get().inputs().containsKey("X")); + assertEquals(1, model.get().requiredFunctions().size()); + assertTrue(model.get().requiredFunctions().containsKey("X")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("X")); + model.get().requiredFunctions().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("outputs/Maximum", output.getBody().getName()); + assertEquals("outputs/Maximum", output.getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getBody().getRoot().toString()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + output.getRoot().toString()); + model.assertEqualResult("X", output.getName()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java index 3d8d5d5a570..add66eece1a 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -21,10 +20,11 @@ public class MnistImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.getBody().getName()); - model.assertEqualResultSum("input", output.getBody().getName(), 0.00001); + assertEquals("dnn/outputs/add", output.getName()); + model.assertEqualResultSum("input", output.getName(), 0.00001); } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index b6e83404ab1..e20ac16a691 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,6 +1,5 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -28,28 +27,27 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), - constant0.type()); + constant0.type()); assertEquals(7840, constant0.size()); Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), + constant1.type()); assertEquals(10, constant1.size()); - // Check inputs - assertEquals(1, model.inputs().size()); - assertTrue(model.inputs().containsKey("Placeholder")); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); + // Check required functions (inputs) + assertEquals(1, model.requiredFunctions().size()); + assertTrue(model.requiredFunctions().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.requiredFunctions().get("Placeholder")); - // Check signature - ExpressionFunction output = model.defaultSignature().outputExpression("add"); + // Check outputs + RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.getBody().getName()); + assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getBody().getRoot().toString()); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), - model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); - assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + output.getRoot().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 0a48ecfce21..ef28eb4678f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -39,10 +38,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals(0, model.get().functions().size()); // Check required functions - assertEquals(1, model.get().inputs().size()); - assertTrue(model.get().inputs().containsKey("Placeholder")); + assertEquals(1, model.get().requiredFunctions().size()); + assertTrue(model.get().requiredFunctions().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("Placeholder")); + model.get().requiredFunctions().get("Placeholder")); // Check signatures assertEquals(1, model.get().signatures().size()); @@ -57,12 +56,11 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("add", output.getBody().getName()); + assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.getBody().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + output.getRoot().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |