summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-29 07:24:41 -0800
committerGitHub <noreply@github.com>2018-11-29 07:24:41 -0800
commitf53959bddcd482b262276c4d3ae75fa754f82394 (patch)
tree3557fb25956d8b65273d77c1903cb8097c06bb1a /model-integration
parent71eccb08d9b88fb4b5485baed3fdc2cc17822a6f (diff)
Revert "Revert "Always output tensor type in toString""
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java35
1 files changed, 10 insertions, 25 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index ce91b8bb141..d7ac8bc90b2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -72,7 +72,7 @@ public class ImportedModel implements ImportedMlModel {
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
@Override
- public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
+ public Map<String, String> smallConstants() { return asStrings(smallConstants); }
boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
@@ -82,7 +82,7 @@ public class ImportedModel implements ImportedMlModel {
* For TensorFlow this corresponds to Variable files stored separately.
*/
@Override
- public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
+ public Map<String, String> largeConstants() { return asStrings(largeConstants); }
boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
@@ -133,7 +133,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(signatureEntry.getKey(),
new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()).getRoot().toString(),
- asTensorTypeStrings(signatureEntry.getValue().inputMap()),
+ asStrings(signatureEntry.getValue().inputMap()),
Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
@@ -142,7 +142,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(singleEntry.getKey(),
new ArrayList<>(inputs.keySet()),
singleEntry.getValue().getRoot().toString(),
- asTensorTypeStrings(inputs),
+ asStrings(inputs),
Optional.empty()));
}
else {
@@ -150,7 +150,7 @@ public class ImportedModel implements ImportedMlModel {
functions.add(new ImportedMlFunction(expressionEntry.getKey(),
new ArrayList<>(inputs.keySet()),
expressionEntry.getValue().getRoot().toString(),
- asTensorTypeStrings(inputs),
+ asStrings(inputs),
Optional.empty()));
}
}
@@ -158,26 +158,13 @@ public class ImportedModel implements ImportedMlModel {
return functions;
}
- private Map<String, String> asTensorStrings(Map<String, Tensor> map) {
+ private Map<String, String> asStrings(Map<String, ?> map) {
HashMap<String, String> values = new HashMap<>();
- for (Map.Entry<String, Tensor> entry : map.entrySet()) {
- Tensor tensor = entry.getValue();
- // TODO: See Tensor.toStandardString
- if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty())
- values.put(entry.getKey(), tensor.toString());
- else
- values.put(entry.getKey(), tensor.type() + ":" + tensor);
- }
+ for (Map.Entry<String, ?> entry : map.entrySet())
+ values.put(entry.getKey(), entry.getValue().toString());
return values;
}
- private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) {
- Map<String, String> stringMap = new HashMap<>();
- for (Map.Entry<String, TensorType> entry : map.entrySet())
- stringMap.put(entry.getKey(), entry.getValue().toString());
- return stringMap;
- }
-
private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
HashMap<String, String> values = new HashMap<>();
for (Map.Entry<String, RankingExpression> entry : map.entrySet())
@@ -237,9 +224,7 @@ public class ImportedModel implements ImportedMlModel {
*/
public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
+ /** Returns an immutable list of possibly non-fatal warnings encountered during import. */
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references as an imported function */
@@ -247,7 +232,7 @@ public class ImportedModel implements ImportedMlModel {
return new ImportedMlFunction(functionName,
new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
- asTensorTypeStrings(inputMap()),
+ asStrings(inputMap()),
Optional.empty());
}