diff options
author | HÃ¥kon Hallingstad <hakon@oath.com> | 2018-11-29 12:27:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-29 12:27:41 +0100 |
commit | 2a3dfcc9b50fc16dd9551958d9614cc9ff0d5be3 (patch) | |
tree | 59e949997bc21749f184e6bede41005cc4173f24 /model-integration | |
parent | aabdc58047ccb7597ccbe54ee35858792116197f (diff) |
Revert "Always output tensor type in toString"
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index d7ac8bc90b2..ce91b8bb141 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel { * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ @Override - public Map<String, String> smallConstants() { return asStrings(smallConstants); } + public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } @@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel { * For TensorFlow this corresponds to Variable files stored separately. */ @Override - public Map<String, String> largeConstants() { return asStrings(largeConstants); } + public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } @@ -133,7 +133,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(signatureEntry.getKey(), new ArrayList<>(signatureEntry.getValue().inputs().values()), expressions().get(signatureEntry.getKey()).getRoot().toString(), - asStrings(signatureEntry.getValue().inputMap()), + asTensorTypeStrings(signatureEntry.getValue().inputMap()), Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures @@ -142,7 +142,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(singleEntry.getKey(), new ArrayList<>(inputs.keySet()), singleEntry.getValue().getRoot().toString(), - asStrings(inputs), + asTensorTypeStrings(inputs), Optional.empty())); } else { @@ -150,7 +150,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(expressionEntry.getKey(), new ArrayList<>(inputs.keySet()), expressionEntry.getValue().getRoot().toString(), - asStrings(inputs), + asTensorTypeStrings(inputs), Optional.empty())); } } @@ -158,13 +158,26 @@ public class ImportedModel implements ImportedMlModel { return functions; } - private Map<String, String> asStrings(Map<String, ?> map) { + private Map<String, String> asTensorStrings(Map<String, Tensor> map) { HashMap<String, String> values = new HashMap<>(); - for (Map.Entry<String, ?> entry : map.entrySet()) - values.put(entry.getKey(), entry.getValue().toString()); + for (Map.Entry<String, Tensor> entry : map.entrySet()) { + Tensor tensor = entry.getValue(); + // TODO: See Tensor.toStandardString + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) + values.put(entry.getKey(), tensor.toString()); + else + values.put(entry.getKey(), tensor.type() + ":" + tensor); + } return values; } + private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) { + Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { HashMap<String, String> values = new HashMap<>(); for (Map.Entry<String, RankingExpression> entry : map.entrySet()) @@ -224,7 +237,9 @@ public class ImportedModel implements ImportedMlModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - /** Returns an immutable list of possibly non-fatal warnings encountered during import. */ + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references as an imported function */ @@ -232,7 +247,7 @@ public class ImportedModel implements ImportedMlModel { return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), owner().expressions().get(outputs.get(outputName)).getRoot().toString(), - asStrings(inputMap()), + asTensorTypeStrings(inputMap()), Optional.empty()); } |