summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-06 16:38:28 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-06 16:38:28 +0200
commitef211a55b4a343ad8bcd8ae34a202f3c61828a7a (patch)
tree4f62a9363a9b48bd8875e6868fc9f974d37b9b5b
parentc1fdecf3cb26f1a3aef2caf290916a4f533c6c58 (diff)
Send global constants
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java15
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java6
-rwxr-xr-xconfig-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java1
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java32
-rwxr-xr-xdocumentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java2
-rw-r--r--documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java2
-rw-r--r--linguistics/src/main/java/com/yahoo/language/detect/Detection.java2
-rw-r--r--linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java2
-rw-r--r--linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java1
-rwxr-xr-xmessagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java28
-rwxr-xr-xmessagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java5
-rwxr-xr-xmessagebus/src/main/java/com/yahoo/messagebus/routing/Route.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java38
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java66
15 files changed, 133 insertions, 71 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
index 164cb7f808e..5ac1418c0c7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
@@ -1,6 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.config.FileReference;
+import com.yahoo.vespa.model.AbstractService;
+import com.yahoo.vespa.model.utils.FileSender;
+
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -33,4 +38,14 @@ public class RankingConstants {
return Collections.unmodifiableMap(constants);
}
+ /** Initiate sending of these constants to some services over file distribution */
+ public void sendTo(Collection<? extends AbstractService> services) {
+ for (RankingConstant constant : constants.values()) {
+ FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE)
+ ? FileSender.sendFileToServices(constant.getFileName(), services)
+ : FileSender.sendUriToServices(constant.getUri(), services);
+ constant.setFileReference(reference.value());
+ }
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index e58b3da4f72..fcbfb47c597 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -10,7 +10,9 @@ import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.Search;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.model.AbstractService;
+import java.util.Collection;
import java.util.Map;
import java.util.logging.Logger;
@@ -79,6 +81,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
return rankProfiles.get(name);
}
+ public void sendConstantsTo(Collection<? extends AbstractService> services) {
+ rankingConstants.sendTo(services);
+ }
+
@Override
public String getDerivedName() { return "rank-profiles"; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
index 0b434fd0c49..fbe86d26b02 100755
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
@@ -87,7 +87,6 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
-import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -361,6 +360,7 @@ public final class ContainerCluster
public void prepare() {
addAndSendApplicationBundles();
+ rankProfileList.sendConstantsTo(containers);
sendUserConfiguredFiles();
setApplicationMetaData();
for (RestApi restApi : restApiGroup.getComponents())
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java
index 19d014e0a1d..ceb48732116 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java
@@ -15,6 +15,7 @@ import java.util.*;
/**
* Config producer for the FederationSearcher.
+ *
* @author Tony Vaagenes
*/
public class FederationSearcher extends Searcher<FederationSearcherModel> implements FederationConfig.Producer {
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
index 91d7fd436f3..a16c32d47ab 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
@@ -70,31 +70,41 @@ public class ModelEvaluationTest {
assertEquals(4, config.rankprofile().size());
Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
assertTrue(modelNames.contains("xgboost_2_2"));
+ assertTrue(modelNames.contains("mnist_saved"));
assertTrue(modelNames.contains("mnist_softmax"));
assertTrue(modelNames.contains("mnist_softmax_saved"));
ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig);
assertEquals(4, evaluator.models().size());
+
Model xgboost = evaluator.models().get("xgboost_2_2");
assertNotNull(xgboost);
assertNotNull(xgboost.evaluatorOf());
assertNotNull(xgboost.evaluatorOf("xgboost_2_2"));
- Model onnx = evaluator.models().get("mnist_softmax");
- assertNotNull(onnx);
- assertNotNull(onnx.evaluatorOf());
- assertNotNull(onnx.evaluatorOf("default"));
- assertNotNull(onnx.evaluatorOf("default", "add"));
- assertNotNull(onnx.evaluatorOf("default.add"));
+ Model tensorflow_mnist = evaluator.models().get("mnist_saved");
+ assertNotNull(tensorflow_mnist);
+ assertNotNull(tensorflow_mnist.evaluatorOf("serving_default"));
+ assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y"));
+ assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y"));
+ assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y"));
+ assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y"));
+
+ Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax");
+ assertNotNull(onnx_mnist_softmax);
+ assertNotNull(onnx_mnist_softmax.evaluatorOf());
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("default"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add"));
- Model tensorflow = evaluator.models().get("mnist_softmax_saved");
- assertNotNull(tensorflow);
- assertNotNull(tensorflow.evaluatorOf());
- assertNotNull(tensorflow.evaluatorOf("serving_default"));
- assertNotNull(tensorflow.evaluatorOf("serving_default", "y"));
+ Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved");
+ assertNotNull(tensorflow_mnist_softmax);
+ assertNotNull(tensorflow_mnist_softmax.evaluatorOf());
+ assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default"));
+ assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y"));
}
}
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java
index 7423792693b..dbf68106e07 100755
--- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java
@@ -20,7 +20,7 @@ import java.util.List;
public class ANDPolicy implements DocumentProtocolRoutingPolicy {
// A list of hops that are to always be selected when select() is invoked.
- private final List<Hop> hops = new ArrayList<Hop>();
+ private final List<Hop> hops = new ArrayList<>();
/**
* Constructs a new AND policy that requires all recipients to be ok for it to merge their replies to an ok reply.
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java
index 82679e17990..a5b3accac68 100644
--- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java
@@ -51,7 +51,7 @@ public class MessageTypePolicy implements DocumentProtocolRoutingPolicy, ConfigS
@Override
public void configure(MessagetyperouteselectorpolicyConfig cfg) {
- Map<Integer, Route> h = new HashMap<Integer, Route>();
+ Map<Integer, Route> h = new HashMap<>();
for (MessagetyperouteselectorpolicyConfig.Route selector : cfg.route()) {
h.put(selector.messagetype(), Route.parse(selector.name()));
}
diff --git a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java
index c08bdc14cfb..4b816335154 100644
--- a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java
+++ b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java
@@ -7,7 +7,7 @@ import java.nio.charset.Charset;
import java.nio.charset.UnsupportedCharsetException;
/**
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
+ * @author Einar M R Rosenvinge
*/
public class Detection {
diff --git a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java
index 12de309a2d3..7451a7f2c9c 100644
--- a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java
+++ b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java
@@ -4,8 +4,10 @@ import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
public class OpenNlpLinguistics extends SimpleLinguistics {
+
@Override
public Tokenizer getTokenizer() {
return new OpenNlpTokenizer(getNormalizer(), getTransformer());
}
+
}
diff --git a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java
index 2b31f95675b..f29ec691c60 100644
--- a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java
+++ b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java
@@ -37,6 +37,7 @@ import java.util.Locale;
* @author bjorncs
*/
public class SimpleDetector implements Detector {
+
static private TextObjectFactory textObjectFactory;
static private LanguageDetector languageDetector;
diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java
index 63514eca6dd..e21aeef1ee2 100755
--- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java
+++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java
@@ -20,15 +20,15 @@ public class Hop {
private String cache = null;
/**
- * <p>Constructs an empty hop. You will need to add directives to the
- * selector to make this usable.</p>
+ * Constructs an empty hop. You will need to add directives to the
+ * selector to make this usable.
*/
public Hop() {
// empty
}
/**
- * <p>Implements the copy constructor.</p>
+ * Implements the copy constructor.
*
* @param hop The hop to copy.
*/
@@ -38,8 +38,8 @@ public class Hop {
}
/**
- * <p>Constructs a fully populated hop. This is package private and used by
- * the {@link HopBlueprint#create()} method.</p>
+ * Constructs a fully populated hop. This is package private and used by
+ * the {@link HopBlueprint#create()} method.
*
* @param selector The selector to copy.
* @param ignoreResult Whether or not to ignore the result of this hop.
@@ -50,8 +50,8 @@ public class Hop {
}
/**
- * <p>Parses the given string as a single hop. The {@link #toString()}
- * method is compatible with this parser.</p>
+ * Parses the given string as a single hop. The {@link #toString()}
+ * method is compatible with this parser.
*
* @param str The string to parse.
* @return A hop that corresponds to the string.
@@ -65,8 +65,7 @@ public class Hop {
}
/**
- * <p>Returns whether or not there are any directives contained in this
- * hop.</p>
+ * Returns whether or not there are any directives contained in this hop.
*
* @return True if there is at least one directive.
*/
@@ -75,7 +74,7 @@ public class Hop {
}
/**
- * <p>Returns the number of directives contained in this hop.</p>
+ * Returns the number of directives contained in this hop.
*
* @return The number of directives.
*/
@@ -84,7 +83,7 @@ public class Hop {
}
/**
- * <p>Returns the directive at the given index.</p>
+ * Returns the directive at the given index.
*
* @param i The index of the directive to return.
* @return The item.
@@ -94,7 +93,7 @@ public class Hop {
}
/**
- * <p>Adds a new directive to this hop.</p>
+ * Adds a new directive to this hop.
*
* @param directive The directive to add.
* @return This, to allow chaining.
@@ -106,7 +105,7 @@ public class Hop {
}
/**
- * <p>Sets the directive at a given index.</p>
+ * Sets the directive at a given index.
*
* @param i The index at which to set the directive.
* @param directive The directive to set.
@@ -283,9 +282,10 @@ public class Hop {
@Override
public int hashCode() {
- int result = selector != null ? selector.hashCode() : 0;
+ int result = selector.hashCode();
result = 31 * result + (ignoreResult ? 1 : 0);
result = 31 * result + (cache != null ? cache.hashCode() : 0);
return result;
}
+
}
diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java
index 809b2da69c4..838b11e7a02 100755
--- a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java
+++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java
@@ -14,13 +14,14 @@ public interface HopDirective {
* @param dir The directive to compare this to.
* @return True if this matches the argument.
*/
- public boolean matches(HopDirective dir);
+ boolean matches(HopDirective dir);
/**
* Returns a string representation of this that can be debugged but not parsed.
*
* @return The debug string.
*/
- public String toDebugString();
+ String toDebugString();
+
}
diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java
index a07c6e16100..9190b680ebf 100755
--- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java
+++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java
@@ -20,7 +20,7 @@ import java.util.List;
*/
public class Route {
- private final List<Hop> hops = new ArrayList<Hop>();
+ private final List<Hop> hops = new ArrayList<>();
private String cache = null;
/**
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index b9e7a27c013..98c80ace047 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
/**
* Converts RankProfilesConfig instances to RankingExpressions for evaluation.
@@ -35,19 +36,13 @@ import java.util.Optional;
*/
class RankProfilesConfigImporter {
- /**
- * Constants already imported in this while reading some expression.
- * This is to avoid re-reading constants referenced
- * multiple places, as that is potentially costly.
- */
- private Map<String, Constant> globalImportedConstants = new HashMap<>();
+ private static final Logger log = Logger.getLogger("CONSTANTS");
/**
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
- globalImportedConstants.clear();
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
@@ -61,7 +56,8 @@ class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException {
+ private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
+ throws ParseException {
List<ExpressionFunction> functions = new ArrayList<>();
Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
ExpressionFunction firstPhase = null;
@@ -79,7 +75,8 @@ class RankProfilesConfigImporter {
functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
// Make all functions, bound or not available under the name they are referenced by in expressions
- referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ referencedFunctions.put(reference.get(),
+ new ExpressionFunction(reference.get().serialForm(), arguments, expression));
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
@@ -112,24 +109,29 @@ class RankProfilesConfigImporter {
private List<Constant> readConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();
+
for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
constants.add(new Constant(constantConfig.name(),
- readTensorFromFile(TensorType.fromSpec(constantConfig.type()),
+ readTensorFromFile(constantConfig.name(),
+ TensorType.fromSpec(constantConfig.type()),
constantConfig.fileref().value())));
}
return constants;
}
- private Tensor readTensorFromFile(TensorType type, String fileName) {
+ private Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
try {
- if (fileName.endsWith(".tbf"))
- return TypedBinaryFormat.decode(Optional.of(type),
- GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName))));
- // TODO: Support json and json.lz4
-
- if (fileName.isEmpty()) // this is the case in unit tests
+ if (fileReference.isEmpty()) { // this may be the case in unit tests
+ log.warning("Got empty file reference for constant '" + name + "', using an empty tensor");
return Tensor.from(type, "{}");
- throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName);
+ }
+ if ( ! new File(fileReference).exists()) { // this may be the case in unit tests
+ log.warning("Got empty file reference for constant '" + name + "', using an empty tensor");
+ return Tensor.from(type, "{}");
+ }
+ return TypedBinaryFormat.decode(Optional.of(type),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference))));
+ // TODO: Support json and json.lz4
}
catch (IOException e) {
throw new UncheckedIOException(e);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
index 84e01e58280..2cb9602dfa7 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
@@ -11,9 +11,11 @@ import org.junit.Test;
import java.io.File;
import java.util.Map;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* Tests instantiating models from rank-profiles configs.
@@ -28,27 +30,49 @@ public class RankProfilesImporterTest {
assertEquals(4, models.size());
- Model xgboost = models.get("xgboost_2_2");
- assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
-
- Model onnxMnistSoftmax = models.get("mnist_softmax");
- assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- assertEquals("tensor(d1[10],d2[784])",
- onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
-
- Model tfMnistSoftmax = models.get("mnist_softmax_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
-
- Model tfMnist = models.get("mnist_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
+ // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
+ {
+ Model xgboost = models.get("xgboost_2_2");
+ assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ FunctionEvaluator evaluator = xgboost.evaluatorOf();
+ assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+
+ Model onnxMnistSoftmax = models.get("mnist_softmax");
+ assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
+ FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnistSoftmax = models.get("mnist_softmax_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnist = models.get("mnist_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ // Macro:
+ assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
+ assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
}
@Test