diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-28 13:02:02 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-28 13:02:02 +0100 |
commit | 7cbf6c75962e44f76c64c2442b45c83fe275fdcb (patch) | |
tree | 608e5d93e62e55c24bddca35f874d5100466e214 /model-integration | |
parent | e12e2d54042b2aeca632ee630f0d67695dfb2f1b (diff) |
Always output tensor type in toString
This allows us to always restore a tensor accurately from
its toString form.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java | 35 |
1 files changed, 10 insertions, 25 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0c5866b87fa..9461c391951 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel { * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ @Override - public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + public Map<String, String> smallConstants() { return asStrings(smallConstants); } boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } @@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel { * For TensorFlow this corresponds to Variable files stored separately. */ @Override - public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + public Map<String, String> largeConstants() { return asStrings(largeConstants); } boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } @@ -136,7 +136,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(signatureEntry.getKey(), new ArrayList<>(signatureEntry.getValue().inputs().values()), expressions().get(signatureEntry.getKey()).getRoot().toString(), - asTensorTypeStrings(signatureEntry.getValue().inputMap()), + asStrings(signatureEntry.getValue().inputMap()), Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures @@ -145,7 +145,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(singleEntry.getKey(), new ArrayList<>(inputs.keySet()), singleEntry.getValue().getRoot().toString(), - asTensorTypeStrings(inputs), + asStrings(inputs), Optional.empty())); } else { @@ -153,7 +153,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(expressionEntry.getKey(), new ArrayList<>(inputs.keySet()), expressionEntry.getValue().getRoot().toString(), - asTensorTypeStrings(inputs), + asStrings(inputs), Optional.empty())); } } @@ -161,26 +161,13 @@ public class ImportedModel implements ImportedMlModel { return functions; } - private Map<String, String> asTensorStrings(Map<String, Tensor> map) { + private Map<String, String> asStrings(Map<String, ?> map) { HashMap<String, String> values = new HashMap<>(); - for (Map.Entry<String, Tensor> entry : map.entrySet()) { - Tensor tensor = entry.getValue(); - // TODO: See Tensor.toStandardString - if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) - values.put(entry.getKey(), tensor.toString()); - else - values.put(entry.getKey(), tensor.type() + ":" + tensor); - } + for (Map.Entry<String, ?> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().toString()); return values; } - private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) { - Map<String, String> stringMap = new HashMap<>(); - for (Map.Entry<String, TensorType> entry : map.entrySet()) - stringMap.put(entry.getKey(), entry.getValue().toString()); - return stringMap; - } - private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { HashMap<String, String> values = new HashMap<>(); for (Map.Entry<String, RankingExpression> entry : map.entrySet()) @@ -240,9 +227,7 @@ public class ImportedModel implements ImportedMlModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - /** - * Returns an immutable list of possibly non-fatal warnings encountered during import. - */ + /** Returns an immutable list of possibly non-fatal warnings encountered during import. */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references */ @@ -259,7 +244,7 @@ public class ImportedModel implements ImportedMlModel { return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), owner().expressions().get(outputs.get(outputName)).getRoot().toString(), - asTensorTypeStrings(inputMap()), + asStrings(inputMap()), Optional.empty()); } |