aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-26 12:33:25 +0200
committerGitHub <noreply@github.com>2018-08-26 12:33:25 +0200
commite3f7e831b7ad19aa2ce309a8cdff970b0693d830 (patch)
tree97c8cca675690640e4b942f7b3d886127c233de2
parent727ae90506e72ed0a6695e2d7cb5c719f0152842 (diff)
parentd50393b429497756307b76d01e89a85270276f7a (diff)
Merge pull request #6675 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-9
Bratseth/generate rank profiles for all models part 9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java8
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java14
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/path/Path.java39
6 files changed, 55 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index b645af582e1..0911f567fa1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -494,12 +494,12 @@ public class ConvertedModel {
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
- try {
+ try (Reader reader = new BufferedReader(expressionFile.createReader())){
String name = expressionFile.getPath().getName();
- expressions.put(name, new RankingExpression(name, expressionFile.createReader()));
+ expressions.put(name, new RankingExpression(name, reader));
}
- catch (FileNotFoundException e) {
- throw new IllegalStateException("Expression file removed while reading: " + expressionFile, e);
+ catch (IOException e) {
+ throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e);
}
catch (ParseException e) {
throw new IllegalStateException("Invalid stored expression in " + expressionFile, e);
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
index 8eccc4e7d06..d06752c9b6d 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
@@ -32,35 +32,31 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
- System.out.println(config.rankprofile(2).toString());
assertEquals(3, config.rankprofile().size());
Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
- assertTrue(modelNames.contains("xgboost_2_2_json"));
- assertTrue(modelNames.contains("mnist_softmax_onnx"));
+ assertTrue(modelNames.contains("xgboost_2_2"));
+ assertTrue(modelNames.contains("mnist_softmax"));
assertTrue(modelNames.contains("mnist_softmax_saved"));
ModelsEvaluator evaluator = new ModelsEvaluator(config);
assertEquals(3, evaluator.models().size());
- Model xgboost = evaluator.models().get("xgboost_2_2_json");
+ Model xgboost = evaluator.models().get("xgboost_2_2");
assertNotNull(xgboost);
assertNotNull(xgboost.evaluatorOf());
- assertNotNull(xgboost.evaluatorOf("xgboost_2_2_json"));
- System.out.println("xgboost functions: " + xgboost.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
+ assertNotNull(xgboost.evaluatorOf("xgboost_2_2"));
- Model onnx = evaluator.models().get("mnist_softmax_onnx");
+ Model onnx = evaluator.models().get("mnist_softmax");
assertNotNull(onnx);
assertNotNull(onnx.evaluatorOf());
assertNotNull(onnx.evaluatorOf("default"));
assertNotNull(onnx.evaluatorOf("default", "add"));
- System.out.println("onnx functions: " + onnx.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
Model tensorflow = evaluator.models().get("mnist_softmax_saved");
assertNotNull(tensorflow);
assertNotNull(tensorflow.evaluatorOf());
assertNotNull(tensorflow.evaluatorOf("serving_default"));
assertNotNull(tensorflow.evaluatorOf("serving_default", "y"));
- System.out.println("tensorflow functions: " + tensorflow.functions().stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index a96a3ce798b..815a01cdb99 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -31,7 +31,7 @@ public class RankingExpressionWithOnnxTestCase {
private final Path applicationDir = Path.fromString("src/test/integration/onnx/");
/** The model name */
- private final static String name = "mnist_softmax_onnx";
+ private final static String name = "mnist_softmax";
private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
@@ -54,7 +54,8 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
- String queryProfileType = "<query-profile-type id='root'>" +
+ String queryProfileType =
+ "<query-profile-type id='root'>" +
" <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
@@ -89,7 +90,8 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
- String queryProfileType = "<query-profile-type id='root'>" +
+ String queryProfileType =
+ "<query-profile-type id='root'>" +
" <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
index 827b1911369..b1714b49256 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -6,9 +6,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import java.io.File;
-import java.util.ArrayList;
import java.util.Collection;
-import java.util.List;
import java.util.Optional;
/**
@@ -71,11 +69,20 @@ public class ImportedModels {
return importedModels.values();
}
- private static String toName(File modelPath) {
- String localPath = concatenateAfterModelsDirectory(Path.fromString(modelPath.toString()));
+ private static String toName(File modelFile) {
+ Path modelPath = Path.fromString(modelFile.toString());
+ if (modelFile.isFile())
+ modelPath = stripFileEnding(modelPath);
+ String localPath = concatenateAfterModelsDirectory(modelPath);
return localPath.replace('.', '_');
}
+ private static Path stripFileEnding(Path path) {
+ int dotIndex = path.last().lastIndexOf(".");
+ if (dotIndex <= 0) return path;
+ return path.withLast(path.last().substring(0, dotIndex));
+ }
+
private static String concatenateAfterModelsDirectory(Path path) {
boolean afterModels = false;
StringBuilder result = new StringBuilder();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 41817eb3e62..13718935cef 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -156,7 +156,7 @@ public abstract class ModelImporter {
private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
if (operation.function().isPresent()) {
String name = operation.name();
- if (!model.expressions().containsKey(name)) {
+ if ( ! model.expressions().containsKey(name)) {
TensorFunction function = operation.function().get();
if (isSignatureOutput(model, operation)) {
diff --git a/vespajlib/src/main/java/com/yahoo/path/Path.java b/vespajlib/src/main/java/com/yahoo/path/Path.java
index 85c22c1f29b..135cea06906 100644
--- a/vespajlib/src/main/java/com/yahoo/path/Path.java
+++ b/vespajlib/src/main/java/com/yahoo/path/Path.java
@@ -15,6 +15,7 @@ import java.util.stream.Collectors;
* Represents a path represented by a list of elements. Immutable
*
* @author Ulf Lilleengen
+ * @author bratseth
*/
@Beta
public final class Path {
@@ -22,23 +23,23 @@ public final class Path {
private final String delimiter;
private final ImmutableList<String> elements;
- /**
- * Create an empty path.
- */
+ /** Creates an empty path */
private Path(String delimiter) {
this(new ArrayList<>(), delimiter);
}
/**
- * Create a new path as a copy of the provided path.
- * @param rhs the path to copy.
+ * Create a new path as a copy of the provided path
+ *
+ * @param path the path to copy
*/
- private Path(Path rhs) {
- this(rhs.elements, rhs.delimiter);
+ private Path(Path path) {
+ this(path.elements, path.delimiter);
}
/**
- * Create path with given elements.
+ * Create path with given elements
+ *
* @param elements a list of path elements
*/
private Path(List<String> elements, String delimiter) {
@@ -74,8 +75,8 @@ public final class Path {
* Appends a path to another path, thereby creating a new path with the provided path
* appended to this.
*
- * @param path the path to append.
- * @return a new path with argument appended to it.
+ * @param path the path to append
+ * @return a new path with argument appended to it
*/
public Path append(Path path) {
List<String> newElements = new ArrayList<>(this.elements);
@@ -125,6 +126,24 @@ public final class Path {
return new Path(childElements, delimiter);
}
+ /** Returns the last element in this, or the empty string if this path is empty */
+ public String last() {
+ if (elements.isEmpty()) return "";
+ return elements.get(elements.size() - 1);
+ }
+
+ /**
+ * Returns a new path with the last element replaced by the given element.
+ *
+ * @throws IllegalStateException if this path is empty
+ */
+ public Path withLast(String element) {
+ if (element.isEmpty()) throw new IllegalStateException("Cannot set the last element of an empty path");
+ List<String> newElements = new ArrayList<>(elements);
+ newElements.set(newElements.size() -1, element);
+ return new Path(newElements, delimiter);
+ }
+
/** Returns a string representation of this path where the delimiter is prepended */
public String getAbsolute() {
return delimiter + getRelative();