diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-29 07:24:41 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-29 07:24:41 -0800 |
commit | f53959bddcd482b262276c4d3ae75fa754f82394 (patch) | |
tree | 3557fb25956d8b65273d77c1903cb8097c06bb1a /model-integration | |
parent | 71eccb08d9b88fb4b5485baed3fdc2cc17822a6f (diff) |
Revert "Revert "Always output tensor type in toString""
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java | 35 |
1 files changed, 10 insertions, 25 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index ce91b8bb141..d7ac8bc90b2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel { * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ @Override - public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + public Map<String, String> smallConstants() { return asStrings(smallConstants); } boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } @@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel { * For TensorFlow this corresponds to Variable files stored separately. */ @Override - public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + public Map<String, String> largeConstants() { return asStrings(largeConstants); } boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } @@ -133,7 +133,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(signatureEntry.getKey(), new ArrayList<>(signatureEntry.getValue().inputs().values()), expressions().get(signatureEntry.getKey()).getRoot().toString(), - asTensorTypeStrings(signatureEntry.getValue().inputMap()), + asStrings(signatureEntry.getValue().inputMap()), Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures @@ -142,7 +142,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(singleEntry.getKey(), new ArrayList<>(inputs.keySet()), singleEntry.getValue().getRoot().toString(), - asTensorTypeStrings(inputs), + asStrings(inputs), Optional.empty())); } else { @@ -150,7 +150,7 @@ public class ImportedModel implements ImportedMlModel { functions.add(new ImportedMlFunction(expressionEntry.getKey(), new ArrayList<>(inputs.keySet()), expressionEntry.getValue().getRoot().toString(), - asTensorTypeStrings(inputs), + asStrings(inputs), Optional.empty())); } } @@ -158,26 +158,13 @@ public class ImportedModel implements ImportedMlModel { return functions; } - private Map<String, String> asTensorStrings(Map<String, Tensor> map) { + private Map<String, String> asStrings(Map<String, ?> map) { HashMap<String, String> values = new HashMap<>(); - for (Map.Entry<String, Tensor> entry : map.entrySet()) { - Tensor tensor = entry.getValue(); - // TODO: See Tensor.toStandardString - if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) - values.put(entry.getKey(), tensor.toString()); - else - values.put(entry.getKey(), tensor.type() + ":" + tensor); - } + for (Map.Entry<String, ?> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().toString()); return values; } - private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) { - Map<String, String> stringMap = new HashMap<>(); - for (Map.Entry<String, TensorType> entry : map.entrySet()) - stringMap.put(entry.getKey(), entry.getValue().toString()); - return stringMap; - } - private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { HashMap<String, String> values = new HashMap<>(); for (Map.Entry<String, RankingExpression> entry : map.entrySet()) @@ -237,9 +224,7 @@ public class ImportedModel implements ImportedMlModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - /** - * Returns an immutable list of possibly non-fatal warnings encountered during import. - */ + /** Returns an immutable list of possibly non-fatal warnings encountered during import. */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references as an imported function */ @@ -247,7 +232,7 @@ public class ImportedModel implements ImportedMlModel { return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), owner().expressions().get(outputs.get(outputName)).getRoot().toString(), - asTensorTypeStrings(inputMap()), + asStrings(inputMap()), Optional.empty()); } |