summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHÃ¥kon Hallingstad <hakon@oath.com>2018-11-29 12:27:41 +0100
committerGitHub <noreply@github.com>2018-11-29 12:27:41 +0100
commit2a3dfcc9b50fc16dd9551958d9614cc9ff0d5be3 (patch)
tree59e949997bc21749f184e6bede41005cc4173f24 /model-integration
parentaabdc58047ccb7597ccbe54ee35858792116197f (diff)
Revert "Always output tensor type in toString"
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java35
1 files changed, 25 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index d7ac8bc90b2..ce91b8bb141 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel {
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
@Override
- public Map<String, String> smallConstants() { return asStrings(smallConstants); }
+ public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
@@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel {
* For TensorFlow this corresponds to Variable files stored separately.
*/
@Override
- public Map<String, String> largeConstants() { return asStrings(largeConstants); }
+ public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
@@ -133,7 +133,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(signatureEntry.getKey(),
new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()).getRoot().toString(),
- asStrings(signatureEntry.getValue().inputMap()),
+ asTensorTypeStrings(signatureEntry.getValue().inputMap()),
Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
@@ -142,7 +142,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(singleEntry.getKey(),
new ArrayList<>(inputs.keySet()),
singleEntry.getValue().getRoot().toString(),
- asStrings(inputs),
+ asTensorTypeStrings(inputs),
Optional.empty()));
}
else {
@@ -150,7 +150,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(expressionEntry.getKey(),
new ArrayList<>(inputs.keySet()),
expressionEntry.getValue().getRoot().toString(),
- asStrings(inputs),
+ asTensorTypeStrings(inputs),
Optional.empty()));
}
}
@@ -158,13 +158,26 @@ public class ImportedModel implements ImportedMlModel {
return functions;
}
- private Map<String, String> asStrings(Map<String, ?> map) {
+ private Map<String, String> asTensorStrings(Map<String, Tensor> map) {
HashMap<String, String> values = new HashMap<>();
- for (Map.Entry<String, ?> entry : map.entrySet())
- values.put(entry.getKey(), entry.getValue().toString());
+ for (Map.Entry<String, Tensor> entry : map.entrySet()) {
+ Tensor tensor = entry.getValue();
+ // TODO: See Tensor.toStandardString
+ if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty())
+ values.put(entry.getKey(), tensor.toString());
+ else
+ values.put(entry.getKey(), tensor.type() + ":" + tensor);
+ }
return values;
}
+ private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) {
+ Map<String, String> stringMap = new HashMap<>();
+ for (Map.Entry<String, TensorType> entry : map.entrySet())
+ stringMap.put(entry.getKey(), entry.getValue().toString());
+ return stringMap;
+ }
+
private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
HashMap<String, String> values = new HashMap<>();
for (Map.Entry<String, RankingExpression> entry : map.entrySet())
@@ -224,7 +237,9 @@ public class ImportedModel implements ImportedMlModel {
*/
public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
- /** Returns an immutable list of possibly non-fatal warnings encountered during import. */
+ /**
+ * Returns an immutable list of possibly non-fatal warnings encountered during import.
+ */
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references as an imported function */
@@ -232,7 +247,7 @@ public class ImportedModel implements ImportedMlModel {
return new ImportedMlFunction(functionName,
new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
- asStrings(inputMap()),
+ asTensorTypeStrings(inputMap()),
Optional.empty());
}