summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-28 13:02:02 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-28 13:02:02 +0100
commit7cbf6c75962e44f76c64c2442b45c83fe275fdcb (patch)
tree608e5d93e62e55c24bddca35f874d5100466e214 /model-integration
parente12e2d54042b2aeca632ee630f0d67695dfb2f1b (diff)
Always output tensor type in toString
This allows us to always restore a tensor accurately from its toString form.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java35
1 files changed, 10 insertions, 25 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 0c5866b87fa..9461c391951 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel {
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
@Override
- public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
+ public Map<String, String> smallConstants() { return asStrings(smallConstants); }
boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
@@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel {
* For TensorFlow this corresponds to Variable files stored separately.
*/
@Override
- public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
+ public Map<String, String> largeConstants() { return asStrings(largeConstants); }
boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
@@ -136,7 +136,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(signatureEntry.getKey(),
new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()).getRoot().toString(),
- asTensorTypeStrings(signatureEntry.getValue().inputMap()),
+ asStrings(signatureEntry.getValue().inputMap()),
Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
@@ -145,7 +145,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(singleEntry.getKey(),
new ArrayList<>(inputs.keySet()),
singleEntry.getValue().getRoot().toString(),
- asTensorTypeStrings(inputs),
+ asStrings(inputs),
Optional.empty()));
}
else {
@@ -153,7 +153,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(expressionEntry.getKey(),
new ArrayList<>(inputs.keySet()),
expressionEntry.getValue().getRoot().toString(),
- asTensorTypeStrings(inputs),
+ asStrings(inputs),
Optional.empty()));
}
}
@@ -161,26 +161,13 @@ public class ImportedModel implements ImportedMlModel {
return functions;
}
- private Map<String, String> asTensorStrings(Map<String, Tensor> map) {
+ private Map<String, String> asStrings(Map<String, ?> map) {
HashMap<String, String> values = new HashMap<>();
- for (Map.Entry<String, Tensor> entry : map.entrySet()) {
- Tensor tensor = entry.getValue();
- // TODO: See Tensor.toStandardString
- if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty())
- values.put(entry.getKey(), tensor.toString());
- else
- values.put(entry.getKey(), tensor.type() + ":" + tensor);
- }
+ for (Map.Entry<String, ?> entry : map.entrySet())
+ values.put(entry.getKey(), entry.getValue().toString());
return values;
}
- private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) {
- Map<String, String> stringMap = new HashMap<>();
- for (Map.Entry<String, TensorType> entry : map.entrySet())
- stringMap.put(entry.getKey(), entry.getValue().toString());
- return stringMap;
- }
-
private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
HashMap<String, String> values = new HashMap<>();
for (Map.Entry<String, RankingExpression> entry : map.entrySet())
@@ -240,9 +227,7 @@ public class ImportedModel implements ImportedMlModel {
*/
public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
+ /** Returns an immutable list of possibly non-fatal warnings encountered during import. */
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references */
@@ -259,7 +244,7 @@ public class ImportedModel implements ImportedMlModel {
return new ImportedMlFunction(functionName,
new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
- asTensorTypeStrings(inputMap()),
+ asStrings(inputMap()),
Optional.empty());
}