diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-01 15:07:51 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-01 15:07:51 +0200 |
commit | 75757692132e71c3034a761bb4eb04768e4e1268 (patch) | |
tree | 7afcdd570b65d551b9047ef50c407a1b8048c45a /searchlib | |
parent | c5089d4c8f5c6259190ecbf80fbe0c96f391c218 (diff) |
Add logging of variable types
Diffstat (limited to 'searchlib')
3 files changed, 18 insertions, 17 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 9fe45194423..dc70e694446 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; @@ -16,9 +17,12 @@ import java.io.File; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; public abstract class ModelImporter { + private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); + /** * The main import function. */ @@ -119,20 +123,6 @@ public abstract class ModelImporter { return operation.function(); } -// Tensor tensor; -// if (operation.getConstantValue().isPresent()) { -// Value value = operation.getConstantValue().get(); -// if ( ! (value instanceof TensorValue)) { -// return operation.function(); // scalar values are inserted directly into the expression -// } -// tensor = value.asTensor(); -// } else { -// // Here we use the type from the operation, which will have correct dimension names after name resolving -// tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), -// operation.type().get()); -// operation.setConstantValue(new TensorValue(tensor)); -// } - Value value = operation.getConstantValue().orElseThrow(() -> new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + "is constant but does not have a value.")); @@ -221,8 +211,20 @@ public abstract class ModelImporter { } } + /** + * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ private static void logVariableTypes(IntermediateGraph graph, ImportedModel model) { - // todo + for (IntermediateOperation operation : graph.operations()) { + if ( ! (operation instanceof Constant)) continue; + if ( ! operation.type().isPresent()) continue; // will not happen + + log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java index e5f8728ebe2..18856d4a25f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java @@ -28,7 +28,6 @@ public class TensorConverter { return builder.build(); } - /* todo: support more types */ private static Values readValuesOf(Onnx.TensorProto tensorProto) { if (tensorProto.hasRawData()) { switch (tensorProto.getDataType()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java index 822656916f8..95a77c07590 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java @@ -55,7 +55,7 @@ public class Mean extends IntermediateOperation { return reducedType(inputType, shouldKeepDimensions()); } - // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. + // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override protected TensorFunction lazyGetFunction() { |