summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-01 15:07:51 +0200
committerLester Solbakken <lesters@oath.com>2018-06-01 15:07:51 +0200
commit75757692132e71c3034a761bb4eb04768e4e1268 (patch)
tree7afcdd570b65d551b9047ef50c407a1b8048c45a /searchlib
parentc5089d4c8f5c6259190ecbf80fbe0c96f391c218 (diff)
Add logging of variable types
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java2
3 files changed, 18 insertions, 17 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 9fe45194423..dc70e694446 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
@@ -16,9 +17,12 @@ import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
public abstract class ModelImporter {
+ private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
+
/**
* The main import function.
*/
@@ -119,20 +123,6 @@ public abstract class ModelImporter {
return operation.function();
}
-// Tensor tensor;
-// if (operation.getConstantValue().isPresent()) {
-// Value value = operation.getConstantValue().get();
-// if ( ! (value instanceof TensorValue)) {
-// return operation.function(); // scalar values are inserted directly into the expression
-// }
-// tensor = value.asTensor();
-// } else {
-// // Here we use the type from the operation, which will have correct dimension names after name resolving
-// tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
-// operation.type().get());
-// operation.setConstantValue(new TensorValue(tensor));
-// }
-
Value value = operation.getConstantValue().orElseThrow(() ->
new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
"is constant but does not have a value."));
@@ -221,8 +211,20 @@ public abstract class ModelImporter {
}
}
+ /**
+ * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
private static void logVariableTypes(IntermediateGraph graph, ImportedModel model) {
- // todo
+ for (IntermediateOperation operation : graph.operations()) {
+ if ( ! (operation instanceof Constant)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+
+ log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
index e5f8728ebe2..18856d4a25f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
@@ -28,7 +28,6 @@ public class TensorConverter {
return builder.build();
}
- /* todo: support more types */
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 822656916f8..95a77c07590 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -55,7 +55,7 @@ public class Mean extends IntermediateOperation {
return reducedType(inputType, shouldKeepDimensions());
}
- // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {