diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-13 15:21:44 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-13 15:21:44 +0100 |
commit | 3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 (patch) | |
tree | ec003528946a37b9f0aeb49e1b314fdc6601c26e /searchlib/src | |
parent | 5b67e6f8f641141f848ad3989156151f9f182441 (diff) |
Check agreement between TF and Vespa execution
Diffstat (limited to 'searchlib/src')
24 files changed, 1767 insertions, 1441 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 785ed78492e..0eeb0a9e630 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> * - * @param name The name of the variable whose value to return. - * @return The value of the named variable. + * @param name the name of the variable whose value to return. + * @return the value of the named variable. */ public abstract Value get(String name); + /** Returns a variable as a tensor */ + @Override + public Tensor getTensor(String name) { return get(name).asTensor(); } + /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any * string. This may be used to implement more advanced variables whose * values are calculated at runtime from arguments. Supporting this in a - * context is optional. - * + * context is optional. + * * <p>This default implementation generates a name on the form * <code>name(argument1, argument2, ...argumentN).output</code>. * If there are no arguments the parenthesis are omitted. * If there is no output, the dot is omitted.</p> * - * @param name The name of this variable. - * @param arguments The parsed arguments as given in the textual expression. - * @param output The name of the value to output (to enable one named + * @param name the name of this variable. + * @param arguments the parsed arguments as given in the textual expression. + * @param output the name of the value to output (to enable one named * calculation to output several), or null to output the * "main" (or only) value. */ @@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext { * context subclasses. This default implementation throws * UnsupportedOperationException.</p> * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public Value get(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); } /** - * <p>Lookup by index rather than name directly to a double. This is supported by some optimized + * Lookup by index rather than name directly to a double. This is supported by some optimized * context subclasses. This default implementation throws - * UnsupportedOperationException.</p> + * UnsupportedOperationException. * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public double getDouble(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); @@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext { } /** - * <p>Sets a value to this, or throws an UnsupportedOperationException if - * this is not supported. This default implementation does the latter.</p> * + * Sets a value to this, or throws an UnsupportedOperationException if + * this is not supported. This default implementation does the latter. * - * @param name The name of the variable to set. + * @param name the name of the variable to set. * @param value the value to set. Ownership of this value is transferred to this - if it is mutable * (not frozen) it may be modified during execution - * @since 5.1.5 */ public void put(String name, Value value) { throw new UnsupportedOperationException(this + " does not support variable assignment"); } /** - * <p>Returns all the names available in this, or throws an + * Returns all the names available in this, or throws an * UnsupportedOperationException if this operation is not supported. This - * default implementation does the latter.</p> + * default implementation does the latter. * - * @return The set of all variable names. + * @return the set of all variable names. */ public Set<String> names() { throw new UnsupportedOperationException(this + " does not support return a list of its names"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index ea750295423..2ef4a2ede2f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value { public boolean hasDouble() { return true; } @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + + @Override public Value negate() { return new DoubleValue(-asDouble()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index ac8aba6a617..dad69b31181 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A string value. * * @author bratseth - * @since 5.1.21 */ public class StringValue extends Value { @@ -35,6 +37,11 @@ public class StringValue extends Value { } @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + + @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 49c3ccb7b01..26c30fe5ed2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -2,14 +2,10 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.TensorType; - -import java.util.Collections; -import java.util.Optional; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; /** * A Value containing a tensor. @@ -23,7 +19,7 @@ public class TensorValue extends Value { /** The tensor value of this */ private final Tensor value; - + public TensorValue(Tensor value) { this.value = value; } @@ -131,7 +127,7 @@ public class TensorValue extends Value { public Value compare(TruthOperator operator, Value argument) { return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); } - + private Tensor compareTensor(TruthOperator operator, Tensor argument) { switch (operator) { case LARGER: return value.larger(argument); @@ -152,7 +148,7 @@ public class TensorValue extends Value { else return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); } - + private Tensor functionOnTensor(Function function, Tensor argument) { switch (function) { case min: return value.min(argument); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index b2ccbe572d0..40d70e0022c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * The result of a ranking expression evaluation. @@ -25,6 +27,14 @@ public abstract class Value { return new DoubleValue(asDouble()); } + /** Returns this as a tensor value */ + public abstract Tensor asTensor(); + + /** A utility method for wrapping a sdouble in a rank 0 tensor */ + protected Tensor doubleAsTensor(double value) { + return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build(); + } + /** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */ public abstract boolean hasDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java new file mode 100644 index 00000000000..b4a9b363ade --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -0,0 +1,51 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * The result of importing a TensorFlow model into Vespa: + * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model + * - A list of named constant tensors + * - A list of expected input tensors, with their tensor type + * - A list of warning messages + * + * @author bratseth + */ +// This object can be built incrementally within this package, but is immutable when observed from outside the package +// TODO: Retain signature structure in ImportResult (input + output-expression bundles) +public class ImportResult { + + private final List<RankingExpression> expressions = new ArrayList<>(); + private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); + private final List<String> warnings = new ArrayList<>(); + + void add(RankingExpression expression) { expressions.add(expression); } + void set(String name, Tensor constant) { constants.put(name, constant); } + void set(String name, TensorType argument) { arguments.put(name, argument); } + void warn(String warning) { warnings.add(warning); } + + /** Returns an immutable list of the expressions of this */ + public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); } + + /** Returns an immutable map of the constants of this */ + public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } + + /** Returns an immutable map of the arguments of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + + /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ + public List<String> warnings() { + return warnings.stream().sorted().collect(Collectors.toList()); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java deleted file mode 100644 index 235771bfa9c..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java +++ /dev/null @@ -1,23 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.Tensor; - -/** - * A tensor with a name - * - * @author bratseth - */ -public class NamedTensor { - - private final String name; - private final Tensor tensor; - - public NamedTensor(String name, Tensor tensor) { - this.name = name; - this.tensor = tensor; - } - - public String name() { return name; } - public Tensor tensor() { return tensor; } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 183cfabbd87..e7f7b5ef2f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -10,13 +10,14 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; +import com.yahoo.tensor.functions.TensorFunction; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; +import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; @@ -45,16 +46,35 @@ class OperationMapper { private TensorConverter tensorConverter = new TensorConverter(); TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { - // Note that this generalizes the corresponding TF function as it does not verify that the tensor - // types are the same, with the assumption that this already happened on the TF side - // (and if not, this should do the right thing anyway) ensureArguments(2, arguments, "join"); TypedTensorFunction a = arguments.get(0); TypedTensorFunction b = arguments.get(1); + if (a.type().rank() < b.type().rank()) + throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " + + "but this is not supported when the second argument has a higher rank"); + + TensorFunction bFunction = b.function(); + + if (a.type().rank() > b.type().rank()) { + // Well now we have entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + List<String> renameFrom = new ArrayList<>(); + List<String> renameTo = new ArrayList<>(); + int sizeDifference = a.type().rank() - b.type().rank(); + for (int i = 0; i < b.type().rank(); i++) { + renameFrom.add(b.type().dimensions().get(i).name()); + renameTo.add("d" + (sizeDifference + i)); + } + bFunction = new Rename(bFunction, renameFrom, renameTo); + } - TensorType resultType = Join.outputType(a.type(), b.type()); - Join function = new Join(a.function(), b.function(), doubleFunction); - return new TypedTensorFunction(resultType, function); + Join function = new Join(a.function(), bFunction, doubleFunction); + return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank } TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { @@ -66,35 +86,37 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model, - List<NamedTensor> constants) { - String name; - TensorType type; - if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but has " + - tfNode.getInputList().size()); - name = tfNode.getInput(0); - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); - Session.Runner fetched = model.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> result = fetched.run(); - if ( result.size() != 1) - throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size()); - Tensor constant = tensorConverter.toVespaTensor(result.get(0)); - constants.add(new NamedTensor(name, constant)); - return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); - } - else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name - name = tfNode.getName(); - type = inputs.get(name); - if (type == null) - throw new IllegalArgumentException("An identity operation node is referencing input '" + name + - "', but there is no such input"); - return new TypedTensorFunction(type, new VariableTensor(name)); - } + TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + String name = tfNode.getName(); + TensorType type = result.arguments().get(name); + if (type == null) + throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name + + "', but there is no such input"); + // Included literally in the expression and so must be produced by a separate macro in the rank profile + return new TypedTensorFunction(type, new VariableTensor(name)); + } + + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + if ( ! tfNode.getName().endsWith("/read")) + throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + + "nodes are only supported when reading variables"); + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + + String name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); + Session.Runner fetched = model.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if ( importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + + importedTensors.size()); + Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); + result.set(name, constant); + return new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { @@ -106,21 +128,18 @@ class OperationMapper { if (a.type().rank() != b.type().rank()) throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first - // and the last dimension of the second argument be not present in the first argument, while leaving the + String afterLastDim = "d" + (a.type().rank() + 1); + // Let the first dimension of the second tensor be the same as the second dimension of the first + // and the second dimension of the second argument be not present in the first argument, while leaving the // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. - // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly - - String beforeLastDim = "d" + (a.type().rank() - 1); - String lastDim = "d" + a.type().rank(); - String afterLastDim = "d" + (a.type().rank() + 1); + // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly - Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim), - ImmutableList.of(lastDim, afterLastDim)); - Matmul matmul = new Matmul(a.function(), renamedB, lastDim); - return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim), - new Rename(matmul, afterLastDim, lastDim)); + Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), + ImmutableList.of("d1", afterLastDim)); + Matmul matmul = new Matmul(a.function(), renamedB, "d1"); + return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), + new Rename(matmul, afterLastDim, "d1")); } TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index a74445008b7..df43225c333 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -24,8 +24,10 @@ public class TensorConverter { private TensorType toVespaTensorType(long[] shape) { TensorType.Builder b = new TensorType.Builder(); int dimensionIndex = 0; - for (long dimensionSize : shape) - b.indexed("d" + (dimensionIndex++), (int)dimensionSize); + for (long dimensionSize : shape) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed("d" + (dimensionIndex++), (int) dimensionSize); + } return b.build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 51f1e444e70..33523244129 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -16,11 +16,8 @@ import org.tensorflow.framework.TensorInfo; import org.tensorflow.framework.TensorShapeProto; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.logging.Level; import java.util.stream.Collectors; /** @@ -35,104 +32,100 @@ public class TensorFlowImporter { /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. - * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). + * The name of the model is taken as the db/pbtxt file name (not including the file ending). * * @param modelDir the directory containing the TensorFlow model files to import - * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions - * @param logger a receiver of any messages generated by the import process - * @return the ranking expressions resulting from importing this TenorFlow model */ - public List<RankingExpression> importModel(String modelDir, List<NamedTensor> constants, MessageLogger logger) { - try { - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger); + public ImportResult importModel(String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } catch (IOException e) { - throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); + throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); } + } + public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); + SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); + ImportResult result = new ImportResult(); + importInputs(signature.getInputsMap(), result); + result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); + return result; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + } } - private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model, - List<NamedTensor> constants, MessageLogger logger) { - List<RankingExpression> expressions = new ArrayList<>(); + private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { + ImportResult result = new ImportResult(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap()); + importInputs(signatureEntry.getValue().getInputsMap(), result); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { try { - ExpressionNode result = importOutput(output.getValue(), - inputs, - graph.getGraphDef(), - model, - constants); - expressions.add(new RankingExpression(output.getKey(), result)); + ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); + result.add(new RankingExpression(output.getKey(), node)); } catch (IllegalArgumentException e) { - logger.log(Level.INFO, "Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + + signatureEntry.getValue().getMethodName() + + "': " + Exceptions.toMessageString(e)); } } } - return expressions; + return result; } - private Map<String, TensorType> importInputs(Map<String, TensorInfo> inputInfoMap) { - Map<String, TensorType> inputs = new HashMap<>(); - inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), + private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) { + inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), importTensorType(value.getTensorShape()))); - return inputs; } private TensorType importTensorType(TensorShapeProto tensorShape) { TensorType.Builder b = new TensorType.Builder(); - for (int i = 0; i < tensorShape.getDimCount(); i++) { - int dimensionSize = (int) tensorShape.getDim(i).getSize(); + for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) { + int dimensionSize = (int)dimension.getSize(); if (dimensionSize >= 0) - b.indexed("d" + i, dimensionSize); + b.indexed("d" + b.rank(), dimensionSize); else - b.indexed("d" + i); // unbound size + b.indexed("d" + b.rank()); // unbound size } return b.build(); } - private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph, - SavedModelBundle model, List<NamedTensor> constants) { - NodeDef node = getNode(nameOf(output.getName()), graph); - TensorFunction function = importNode(node, inputs, graph, model, constants).function(); + private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { + return importNode(nameOf(output.getName()), graph, model, result); + } + + private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { + TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, - SavedModelBundle model, List<NamedTensor> constants) { - return tensorFunctionOf(tfNode, inputs, graph, model, constants); + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + return tensorFunctionOf(tfNode, graph, model, result); } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, - Map<String, TensorType> inputs, - GraphDef graph, - SavedModelBundle model, - List<NamedTensor> constants) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos()); - case "identity" : return operationMapper.identity(tfNode, inputs, model, constants); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); + case "placeholder" : return operationMapper.placeholder(tfNode, result); + case "identity" : return operationMapper.identity(tfNode, model, result); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } - private List<TypedTensorFunction> importArguments(NodeDef tfNode, - Map<String, TensorType> inputs, - GraphDef graph, - SavedModelBundle model, - List<NamedTensor> constants) { + private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) .collect(Collectors.toList()); } @@ -151,11 +144,4 @@ public class TensorFlowImporter { return name.split(":")[0]; } - /** An interface which can be implemented to receive messages emitted during import */ - public interface MessageLogger { - - void log(Level level, String message); - - } - } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java index 234d620d02f..5712da77700 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java @@ -3,9 +3,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -/** +/** * A tensor function returning a specific tensor type - * + * * @author bratseth */ final class TypedTensorFunction { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index 71699b379b2..d366c9bfbe5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -14,23 +14,23 @@ import java.util.function.*; /** * A tensor generating function, whose arguments are determined by a tensor type - * + * * @author bratseth */ public class GeneratorLambdaFunctionNode extends CompositeNode { private final TensorType type; private final ExpressionNode generator; - + public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) { if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent())) - throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + + throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + "dimensions, but tried to generate " + type); // TODO: Verify that the function only accesses the given arguments this.type = type; this.generator = generator; } - + @Override public List<ExpressionNode> children() { return Collections.singletonList(generator); @@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { public Value evaluate(Context context) { return generator.evaluate(context); } - - /** + + /** * Returns this as an operator which converts a list of integers into a double */ public IntegerListToDoubleLambda asIntegerListToDoubleOperator() { @@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { context.put(type.dimensions().get(i).name(), arguments.get(i)); return evaluate(context).asDouble(); } - + @Override public String toString() { return GeneratorLambdaFunctionNode.this.toString(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index d1f4cbddf6e..8af3448ca6f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -36,10 +36,17 @@ public class TensorFunctionNode extends CompositeNode { @Override public List<ExpressionNode> children() { return function.functionArguments().stream() - .map(f -> ((TensorFunctionExpressionNode)f).expression) + .map(this::toExpressionNode) .collect(Collectors.toList()); } + private ExpressionNode toExpressionNode(TensorFunction f) { + if (f instanceof TensorFunctionExpressionNode) + return ((TensorFunctionExpressionNode)f).expression; + else + return new TensorFunctionNode(f); + } + @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction> wrappedChildren = children.stream() diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py new file mode 100644 index 00000000000..a1861a1c981 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py @@ -0,0 +1,89 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A very simple MNIST classifier. + +See extensive documentation at +https://www.tensorflow.org/get_started/mnist/beginners +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.examples.tutorials.mnist import input_data + +import tensorflow as tf + +FLAGS = None + + +def main(_): + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + W = tf.Variable(tf.zeros([784, 10])) + b = tf.Variable(tf.zeros([10])) + y = tf.matmul(x, W) + b + + # Define loss and optimizer + y_ = tf.placeholder(tf.float32, [None, 10]) + + # The raw formulation of cross-entropy, + # + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # reduction_indices=[1])) + # + # can be numerically unstable. + # + # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw + # outputs of 'y', and then average across the batch. + cross_entropy = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) + + sess = tf.InteractiveSession() + tf.global_variables_initializer().run() + # Train + for _ in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) + + # Test trained model + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + print(sess.run(accuracy, feed_dict={x: mnist.test.images, + y_: mnist.test.labels})) + + # Save the model + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt index e01688669a1..8100dfd594d 100644 --- a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt @@ -237,6 +237,45 @@ meta_graphs { } } op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { name: "Const" output_arg { name: "output" @@ -291,14 +330,14 @@ meta_graphs { is_commutative: true } op { - name: "Fill" + name: "ExpandDims" input_arg { - name: "dims" - type: DT_INT32 + name: "input" + type_attr: "T" } input_arg { - name: "value" - type_attr: "T" + name: "dim" + type_attr: "Tdim" } output_arg { name: "output" @@ -308,48 +347,28 @@ meta_graphs { name: "T" type: "type" } - } - op { - name: "HashTableV2" - output_arg { - name: "table_handle" - type: DT_RESOURCE - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } attr { - name: "shared_name" - type: "string" + name: "Tdim" + type: "type" default_value { - s: "" + type: DT_INT32 } - } - attr { - name: "use_node_name_sharing" - type: "bool" - default_value { - b: false + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } } } - attr { - name: "key_dtype" - type: "type" - } - attr { - name: "value_dtype" - type: "type" - } - is_stateful: true } op { - name: "Identity" + name: "Fill" input_arg { - name: "input" + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" type_attr: "T" } output_arg { @@ -362,37 +381,17 @@ meta_graphs { } } op { - name: "InitializeTableV2" - input_arg { - name: "table_handle" - type: DT_RESOURCE - } - input_arg { - name: "keys" - type_attr: "Tkey" - } + name: "FloorDiv" input_arg { - name: "values" - type_attr: "Tval" - } - attr { - name: "Tkey" - type: "type" - } - attr { - name: "Tval" - type: "type" + name: "x" + type_attr: "T" } - is_stateful: true - } - op { - name: "Log" input_arg { - name: "x" + name: "y" type_attr: "T" } output_arg { - name: "y" + name: "z" type_attr: "T" } attr { @@ -403,6 +402,12 @@ meta_graphs { type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 type: DT_COMPLEX64 type: DT_COMPLEX128 } @@ -410,32 +415,19 @@ meta_graphs { } } op { - name: "LookupTableFindV2" - input_arg { - name: "table_handle" - type: DT_RESOURCE - } - input_arg { - name: "keys" - type_attr: "Tin" - } + name: "Identity" input_arg { - name: "default_value" - type_attr: "Tout" + name: "input" + type_attr: "T" } output_arg { - name: "values" - type_attr: "Tout" - } - attr { - name: "Tin" - type: "type" + name: "output" + type_attr: "T" } attr { - name: "Tout" + name: "T" type: "type" } - is_stateful: true } op { name: "MatMul" @@ -481,6 +473,35 @@ meta_graphs { } } op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { name: "Mean" input_arg { name: "input" @@ -592,32 +613,6 @@ meta_graphs { is_commutative: true } op { - name: "Neg" - input_arg { - name: "x" - type_attr: "T" - } - output_arg { - name: "y" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { name: "NoOp" } op { @@ -650,88 +645,6 @@ meta_graphs { } } op { - name: "ParseExample" - input_arg { - name: "serialized" - type: DT_STRING - } - input_arg { - name: "names" - type: DT_STRING - } - input_arg { - name: "sparse_keys" - type: DT_STRING - number_attr: "Nsparse" - } - input_arg { - name: "dense_keys" - type: DT_STRING - number_attr: "Ndense" - } - input_arg { - name: "dense_defaults" - type_list_attr: "Tdense" - } - output_arg { - name: "sparse_indices" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "sparse_values" - type_list_attr: "sparse_types" - } - output_arg { - name: "sparse_shapes" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "dense_values" - type_list_attr: "Tdense" - } - attr { - name: "Nsparse" - type: "int" - has_minimum: true - } - attr { - name: "Ndense" - type: "int" - has_minimum: true - } - attr { - name: "sparse_types" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "Tdense" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "dense_shapes" - type: "list(shape)" - has_minimum: true - } - } - op { name: "Placeholder" output_arg { name: "output" @@ -752,22 +665,47 @@ meta_graphs { } } op { - name: "Range" - input_arg { - name: "start" - type_attr: "Tidx" - } + name: "Prod" input_arg { - name: "limit" - type_attr: "Tidx" + name: "input" + type_attr: "T" } input_arg { - name: "delta" + name: "reduction_indices" type_attr: "Tidx" } output_arg { name: "output" - type_attr: "Tidx" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } } attr { name: "Tidx" @@ -777,8 +715,6 @@ meta_graphs { } allowed_values { list { - type: DT_FLOAT - type: DT_DOUBLE type: DT_INT32 type: DT_INT64 } @@ -786,15 +722,19 @@ meta_graphs { } } op { - name: "Reciprocal" + name: "RealDiv" input_arg { name: "x" type_attr: "T" } - output_arg { + input_arg { name: "y" type_attr: "T" } + output_arg { + name: "z" + type_attr: "T" + } attr { name: "T" type: "type" @@ -803,6 +743,10 @@ meta_graphs { type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 type: DT_INT32 type: DT_INT64 type: DT_COMPLEX64 @@ -943,13 +887,54 @@ meta_graphs { } } op { - name: "Softmax" + name: "Slice" input_arg { - name: "logits" + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" type_attr: "T" } output_arg { - name: "softmax" + name: "backprop" type_attr: "T" } attr { @@ -1011,6 +996,10 @@ meta_graphs { type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 type: DT_INT32 type: DT_INT64 type: DT_COMPLEX64 @@ -1109,49 +1098,6 @@ meta_graphs { } } op { - name: "TopKV2" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "k" - type: DT_INT32 - } - output_arg { - name: "values" - type_attr: "T" - } - output_arg { - name: "indices" - type: DT_INT32 - } - attr { - name: "sorted" - type: "bool" - default_value { - b: true - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_UINT16 - type: DT_HALF - } - } - } - } - op { name: "VariableV2" output_arg { name: "ref" @@ -1182,224 +1128,27 @@ meta_graphs { } is_stateful: true } - } - tags: "serve" - tensorflow_version: "1.3.0" - tensorflow_git_version: "v1.3.0-rc2-20-g0787eee" - } - graph_def { - node { - name: "tf_example" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - unknown_rank: true - } - } - } - } - node { - name: "ParseExample/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/dense_keys_0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "x" - } - } - } - } - node { - name: "ParseExample/ParseExample" - op: "ParseExample" - input: "tf_example" - input: "ParseExample/ParseExample/names" - input: "ParseExample/ParseExample/dense_keys_0" - input: "ParseExample/Const" - attr { - key: "Ndense" - value { - i: 1 - } - } - attr { - key: "Nsparse" - value { - i: 0 - } - } - attr { - key: "Tdense" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "dense_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - } - } - } - attr { - key: "sparse_types" - value { - list { - } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" } - } - } - node { - name: "x" - op: "Identity" - input: "ParseExample/ParseExample" - attr { - key: "T" - value { - type: DT_FLOAT + output_arg { + name: "y" + type_attr: "T" } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } + attr { + name: "T" + type: "type" } } } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9b01" + } + graph_def { node { name: "Placeholder" op: "Placeholder" @@ -1412,7 +1161,7 @@ meta_graphs { size: -1 } dim { - size: 10 + size: 784 } } } @@ -1432,7 +1181,7 @@ meta_graphs { size: -1 } dim { - size: 10 + size: 784 } } } @@ -1767,15 +1516,9 @@ meta_graphs { } } node { - name: "init" - op: "NoOp" - input: "^Variable/Assign" - input: "^Variable_1/Assign" - } - node { name: "MatMul" op: "MatMul" - input: "x" + input: "Placeholder" input: "Variable/read" attr { key: "T" @@ -1839,15 +1582,8 @@ meta_graphs { } } node { - name: "y" - op: "Softmax" - input: "add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } + name: "Placeholder_1" + op: "Placeholder" attr { key: "_output_shapes" value { @@ -1863,38 +1599,60 @@ meta_graphs { } } } - } - node { - name: "Log" - op: "Log" - input: "y" attr { - key: "T" + key: "dtype" value { type: DT_FLOAT } } attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { key: "_output_shapes" value { list { shape { - dim { - size: -1 - } - dim { - size: 10 - } } } } } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } } node { - name: "mul" - op: "Mul" - input: "Placeholder" - input: "Log" + name: "Shape" + op: "Shape" + input: "add" attr { key: "T" value { @@ -1907,27 +1665,27 @@ meta_graphs { list { shape { dim { - size: -1 - } - dim { - size: 10 + size: 2 } } } } } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } } node { - name: "Const" + name: "Rank_1" op: "Const" attr { key: "_output_shapes" value { list { shape { - dim { - size: 2 - } } } } @@ -1944,20 +1702,16 @@ meta_graphs { tensor { dtype: DT_INT32 tensor_shape { - dim { - size: 2 - } } - tensor_content: "\000\000\000\000\001\000\000\000" + int_val: 2 } } } } node { - name: "Sum" - op: "Sum" - input: "mul" - input: "Const" + name: "Shape_1" + op: "Shape" + input: "add" attr { key: "T" value { @@ -1965,11 +1719,27 @@ meta_graphs { } } attr { - key: "Tidx" + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" value { type: DT_INT32 } } + } + node { + name: "Sub/y" + op: "Const" attr { key: "_output_shapes" value { @@ -1980,20 +1750,32 @@ meta_graphs { } } attr { - key: "keep_dims" + key: "dtype" value { - b: false + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } } } } node { - name: "Neg" - op: "Neg" - input: "Sum" + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2007,46 +1789,51 @@ meta_graphs { } } node { - name: "gradients/Shape" - op: "Const" + name: "Slice/begin" + op: "Pack" + input: "Sub" attr { - key: "_output_shapes" + key: "N" value { - list { - shape { - dim { - } - } - } + i: 1 } } attr { - key: "dtype" + key: "T" value { type: DT_INT32 } } attr { - key: "value" + key: "_output_shapes" value { - tensor { - dtype: DT_INT32 - tensor_shape { + list { + shape { dim { + size: 1 } } } } } + attr { + key: "axis" + value { + i: 0 + } + } } node { - name: "gradients/Const" + name: "Slice/size" op: "Const" attr { key: "_output_shapes" value { list { shape { + dim { + size: 1 + } } } } @@ -2054,50 +1841,40 @@ meta_graphs { attr { key: "dtype" value { - type: DT_FLOAT + type: DT_INT32 } } attr { key: "value" value { tensor { - dtype: DT_FLOAT + dtype: DT_INT32 tensor_shape { + dim { + size: 1 + } } - float_val: 1.0 + int_val: 1 } } } } node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" + key: "Index" value { - list { - shape { - } - } + type: DT_INT32 } } - } - node { - name: "gradients/Neg_grad/Neg" - op: "Neg" - input: "gradients/Fill" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2105,13 +1882,16 @@ meta_graphs { value { list { shape { + dim { + size: 1 + } } } } } } node { - name: "gradients/Sum_grad/Reshape/shape" + name: "concat/values_0" op: "Const" attr { key: "_output_shapes" @@ -2119,7 +1899,7 @@ meta_graphs { list { shape { dim { - size: 2 + size: 1 } } } @@ -2138,55 +1918,66 @@ meta_graphs { dtype: DT_INT32 tensor_shape { dim { - size: 2 + size: 1 } } - tensor_content: "\001\000\000\000\001\000\000\000" + int_val: -1 } } } } node { - name: "gradients/Sum_grad/Reshape" - op: "Reshape" - input: "gradients/Neg_grad/Neg" - input: "gradients/Sum_grad/Reshape/shape" + name: "concat/axis" + op: "Const" attr { - key: "T" + key: "_output_shapes" value { - type: DT_FLOAT + list { + shape { + } + } } } attr { - key: "Tshape" + key: "dtype" value { type: DT_INT32 } } attr { - key: "_output_shapes" + key: "value" value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } + tensor { + dtype: DT_INT32 + tensor_shape { } + int_val: 0 } } } } node { - name: "gradients/Sum_grad/Shape" - op: "Shape" - input: "mul" + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 } } attr { @@ -2201,18 +1992,12 @@ meta_graphs { } } } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } } node { - name: "gradients/Sum_grad/Tile" - op: "Tile" - input: "gradients/Sum_grad/Reshape" - input: "gradients/Sum_grad/Shape" + name: "Reshape" + op: "Reshape" + input: "add" + input: "concat" attr { key: "T" value { @@ -2220,7 +2005,7 @@ meta_graphs { } } attr { - key: "Tmultiples" + key: "Tshape" value { type: DT_INT32 } @@ -2234,7 +2019,7 @@ meta_graphs { size: -1 } dim { - size: 10 + size: -1 } } } @@ -2242,38 +2027,39 @@ meta_graphs { } } node { - name: "gradients/mul_grad/Shape" - op: "Shape" - input: "Placeholder" - attr { - key: "T" - value { - type: DT_FLOAT - } - } + name: "Rank_2" + op: "Const" attr { key: "_output_shapes" value { list { shape { - dim { - size: 2 - } } } } } attr { - key: "out_type" + key: "dtype" value { type: DT_INT32 } } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } } node { - name: "gradients/mul_grad/Shape_1" + name: "Shape_2" op: "Shape" - input: "Log" + input: "Placeholder_1" attr { key: "T" value { @@ -2300,43 +2086,44 @@ meta_graphs { } } node { - name: "gradients/mul_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/mul_grad/Shape" - input: "gradients/mul_grad/Shape_1" + name: "Sub_1/y" + op: "Const" attr { - key: "T" + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" value { type: DT_INT32 } } attr { - key: "_output_shapes" + key: "value" value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } + tensor { + dtype: DT_INT32 + tensor_shape { } + int_val: 1 } } } } node { - name: "gradients/mul_grad/mul" - op: "Mul" - input: "gradients/Sum_grad/Tile" - input: "Log" + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2344,30 +2131,23 @@ meta_graphs { value { list { shape { - dim { - size: -1 - } - dim { - size: 10 - } } } } } } node { - name: "gradients/mul_grad/Sum" - op: "Sum" - input: "gradients/mul_grad/mul" - input: "gradients/mul_grad/BroadcastGradientArgs" + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" attr { - key: "T" + key: "N" value { - type: DT_FLOAT + i: 1 } } attr { - key: "Tidx" + key: "T" value { type: DT_INT32 } @@ -2377,60 +2157,72 @@ meta_graphs { value { list { shape { - unknown_rank: true + dim { + size: 1 + } } } } } attr { - key: "keep_dims" + key: "axis" value { - b: false + i: 0 } } } node { - name: "gradients/mul_grad/Reshape" - op: "Reshape" - input: "gradients/mul_grad/Sum" - input: "gradients/mul_grad/Shape" + name: "Slice_1/size" + op: "Const" attr { - key: "T" + key: "_output_shapes" value { - type: DT_FLOAT + list { + shape { + dim { + size: 1 + } + } + } } } attr { - key: "Tshape" + key: "dtype" value { type: DT_INT32 } } attr { - key: "_output_shapes" + key: "value" value { - list { - shape { - dim { - size: -1 - } + tensor { + dtype: DT_INT32 + tensor_shape { dim { - size: 10 + size: 1 } } + int_val: 1 } } } } node { - name: "gradients/mul_grad/mul_1" - op: "Mul" - input: "Placeholder" - input: "gradients/Sum_grad/Tile" + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2439,10 +2231,7 @@ meta_graphs { list { shape { dim { - size: -1 - } - dim { - size: 10 + size: 1 } } } @@ -2450,44 +2239,113 @@ meta_graphs { } } node { - name: "gradients/mul_grad/Sum_1" - op: "Sum" - input: "gradients/mul_grad/mul_1" - input: "gradients/mul_grad/BroadcastGradientArgs:1" + name: "concat_1/values_0" + op: "Const" attr { - key: "T" + key: "_output_shapes" value { - type: DT_FLOAT + list { + shape { + dim { + size: 1 + } + } + } } } attr { - key: "Tidx" + key: "dtype" value { type: DT_INT32 } } attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { key: "_output_shapes" value { list { shape { - unknown_rank: true } } } } attr { - key: "keep_dims" + key: "dtype" value { - b: false + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } } } } node { - name: "gradients/mul_grad/Reshape_1" + name: "Reshape_1" op: "Reshape" - input: "gradients/mul_grad/Sum_1" - input: "gradients/mul_grad/Shape_1" + input: "Placeholder_1" + input: "concat_1" attr { key: "T" value { @@ -2509,7 +2367,7 @@ meta_graphs { size: -1 } dim { - size: 10 + size: -1 } } } @@ -2517,16 +2375,10 @@ meta_graphs { } } node { - name: "gradients/mul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/mul_grad/Reshape" - input: "^gradients/mul_grad/Reshape_1" - } - node { - name: "gradients/mul_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/mul_grad/Reshape" - input: "^gradients/mul_grad/tuple/group_deps" + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" attr { key: "T" value { @@ -2534,73 +2386,127 @@ meta_graphs { } } attr { - key: "_class" + key: "_output_shapes" value { list { - s: "loc:@gradients/mul_grad/Reshape" + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } } } } + } + node { + name: "Sub_2/y" + op: "Const" attr { key: "_output_shapes" value { list { shape { - dim { - size: -1 - } - dim { - size: 10 - } } } } } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } } node { - name: "gradients/mul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/mul_grad/Reshape_1" - input: "^gradients/mul_grad/tuple/group_deps" + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { - key: "_class" + key: "_output_shapes" value { list { - s: "loc:@gradients/mul_grad/Reshape_1" + shape { + } } } } + } + node { + name: "Slice_2/begin" + op: "Const" attr { key: "_output_shapes" value { list { shape { dim { - size: -1 + size: 1 } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { - size: 10 + size: 1 } } + int_val: 0 } } } } node { - name: "gradients/Log_grad/Reciprocal" - op: "Reciprocal" - input: "y" - input: "^gradients/mul_grad/tuple/control_dependency_1" + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2609,25 +2515,35 @@ meta_graphs { list { shape { dim { - size: -1 - } - dim { - size: 10 + size: 1 } } } } } + attr { + key: "axis" + value { + i: 0 + } + } } node { - name: "gradients/Log_grad/mul" - op: "Mul" - input: "gradients/mul_grad/tuple/control_dependency_1" - input: "gradients/Log_grad/Reciprocal" + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2638,19 +2554,16 @@ meta_graphs { dim { size: -1 } - dim { - size: 10 - } } } } } } node { - name: "gradients/y_grad/mul" - op: "Mul" - input: "gradients/Log_grad/mul" - input: "y" + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" attr { key: "T" value { @@ -2658,6 +2571,12 @@ meta_graphs { } } attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { key: "_output_shapes" value { list { @@ -2665,16 +2584,13 @@ meta_graphs { dim { size: -1 } - dim { - size: 10 - } } } } } } node { - name: "gradients/y_grad/Sum/reduction_indices" + name: "Const" op: "Const" attr { key: "_output_shapes" @@ -2704,16 +2620,16 @@ meta_graphs { size: 1 } } - int_val: 1 + int_val: 0 } } } } node { - name: "gradients/y_grad/Sum" - op: "Sum" - input: "gradients/y_grad/mul" - input: "gradients/y_grad/Sum/reduction_indices" + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" attr { key: "T" value { @@ -2731,9 +2647,6 @@ meta_graphs { value { list { shape { - dim { - size: -1 - } } } } @@ -2746,7 +2659,7 @@ meta_graphs { } } node { - name: "gradients/y_grad/Reshape/shape" + name: "gradients/Shape" op: "Const" attr { key: "_output_shapes" @@ -2754,7 +2667,6 @@ meta_graphs { list { shape { dim { - size: 2 } } } @@ -2773,19 +2685,47 @@ meta_graphs { dtype: DT_INT32 tensor_shape { dim { - size: 2 } } - tensor_content: "\377\377\377\377\001\000\000\000" } } } } node { - name: "gradients/y_grad/Reshape" - op: "Reshape" - input: "gradients/y_grad/Sum" - input: "gradients/y_grad/Reshape/shape" + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" attr { key: "T" value { @@ -2793,32 +2733,56 @@ meta_graphs { } } attr { - key: "Tshape" + key: "_output_shapes" value { - type: DT_INT32 + list { + shape { + } + } } } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" attr { key: "_output_shapes" value { list { shape { dim { - size: -1 + size: 1 } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { size: 1 } } + int_val: 1 } } } } node { - name: "gradients/y_grad/sub" - op: "Sub" - input: "gradients/Log_grad/mul" - input: "gradients/y_grad/Reshape" + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" attr { key: "T" value { @@ -2826,26 +2790,58 @@ meta_graphs { } } attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { key: "_output_shapes" value { list { shape { dim { - size: -1 + size: 1 } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { dim { - size: 10 + size: 1 } } } } } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } } node { - name: "gradients/y_grad/mul_1" - op: "Mul" - input: "gradients/y_grad/sub" - input: "y" + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" attr { key: "T" value { @@ -2853,6 +2849,12 @@ meta_graphs { } } attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { key: "_output_shapes" value { list { @@ -2860,18 +2862,15 @@ meta_graphs { dim { size: -1 } - dim { - size: 10 - } } } } } } node { - name: "gradients/add_grad/Shape" + name: "gradients/Mean_grad/Shape_1" op: "Shape" - input: "MatMul" + input: "Reshape_2" attr { key: "T" value { @@ -2884,7 +2883,7 @@ meta_graphs { list { shape { dim { - size: 2 + size: 1 } } } @@ -2898,7 +2897,7 @@ meta_graphs { } } node { - name: "gradients/add_grad/Shape_1" + name: "gradients/Mean_grad/Shape_2" op: "Const" attr { key: "_output_shapes" @@ -2906,7 +2905,6 @@ meta_graphs { list { shape { dim { - size: 1 } } } @@ -2925,23 +2923,21 @@ meta_graphs { dtype: DT_INT32 tensor_shape { dim { - size: 1 } } - int_val: 10 } } } } node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" + name: "gradients/Mean_grad/Const" + op: "Const" attr { - key: "T" + key: "_class" value { - type: DT_INT32 + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } } } attr { @@ -2950,27 +2946,42 @@ meta_graphs { list { shape { dim { - size: -1 + size: 1 } } - shape { + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { - size: -1 + size: 1 } } + int_val: 0 } } } } node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/y_grad/mul_1" - input: "gradients/add_grad/BroadcastGradientArgs" + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -2980,11 +2991,18 @@ meta_graphs { } } attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { key: "_output_shapes" value { list { shape { - unknown_rank: true } } } @@ -2997,20 +3015,14 @@ meta_graphs { } } node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } + name: "gradients/Mean_grad/Const_1" + op: "Const" attr { - key: "Tshape" + key: "_class" value { - type: DT_INT32 + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } } } attr { @@ -3019,25 +3031,42 @@ meta_graphs { list { shape { dim { - size: -1 + size: 1 } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { - size: 10 + size: 1 } } + int_val: 0 } } } } node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/y_grad/mul_1" - input: "gradients/add_grad/BroadcastGradientArgs:1" + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { @@ -3047,11 +3076,18 @@ meta_graphs { } } attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { key: "_output_shapes" value { list { shape { - unknown_rank: true } } } @@ -3064,57 +3100,59 @@ meta_graphs { } } node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" + name: "gradients/Mean_grad/Maximum/y" + op: "Const" attr { - key: "T" + key: "_class" value { - type: DT_FLOAT + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } } } attr { - key: "Tshape" + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" value { type: DT_INT32 } } attr { - key: "_output_shapes" + key: "value" value { - list { - shape { - dim { - size: 10 - } + tensor { + dtype: DT_INT32 + tensor_shape { } + int_val: 1 } } } } node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - } - node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { key: "_class" value { list { - s: "loc:@gradients/add_grad/Reshape" + s: "loc:@gradients/Mean_grad/Shape_1" } } } @@ -3123,33 +3161,27 @@ meta_graphs { value { list { shape { - dim { - size: -1 - } - dim { - size: 10 - } } } } } } node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" attr { key: "T" value { - type: DT_FLOAT + type: DT_INT32 } } attr { key: "_class" value { list { - s: "loc:@gradients/add_grad/Reshape_1" + s: "loc:@gradients/Mean_grad/Shape_1" } } } @@ -3158,58 +3190,65 @@ meta_graphs { value { list { shape { - dim { - size: 10 - } } } } } } node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "Variable/read" + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" attr { - key: "T" + key: "DstT" value { type: DT_FLOAT } } attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { key: "_output_shapes" value { list { shape { - dim { - size: -1 - } - dim { - size: 784 - } } } } } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" attr { - key: "transpose_a" + key: "T" value { - b: false + type: DT_FLOAT } } attr { - key: "transpose_b" + key: "_output_shapes" value { - b: true + list { + shape { + dim { + size: -1 + } + } + } } } } node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "x" - input: "gradients/add_grad/tuple/control_dependency" + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" attr { key: "T" value { @@ -3222,39 +3261,24 @@ meta_graphs { list { shape { dim { - size: 784 - } - dim { - size: 10 + size: 1 } } } } } attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" + key: "out_type" value { - b: false + type: DT_INT32 } } } node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - } - node { - name: "gradients/MatMul_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/tuple/group_deps" + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" attr { key: "T" value { @@ -3262,11 +3286,9 @@ meta_graphs { } } attr { - key: "_class" + key: "Tshape" value { - list { - s: "loc:@gradients/MatMul_grad/MatMul" - } + type: DT_INT32 } } attr { @@ -3277,19 +3299,15 @@ meta_graphs { dim { size: -1 } - dim { - size: 784 - } } } } } } node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" attr { key: "T" value { @@ -3297,23 +3315,15 @@ meta_graphs { } } attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { key: "_output_shapes" value { list { shape { dim { - size: 784 + size: -1 } dim { - size: 10 + size: -1 } } } @@ -3321,7 +3331,7 @@ meta_graphs { } } node { - name: "GradientDescent/learning_rate" + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" op: "Const" attr { key: "_output_shapes" @@ -3335,27 +3345,26 @@ meta_graphs { attr { key: "dtype" value { - type: DT_FLOAT + type: DT_INT32 } } attr { key: "value" value { tensor { - dtype: DT_FLOAT + dtype: DT_INT32 tensor_shape { } - float_val: 0.00999999977648 + int_val: -1 } } } } node { - name: "GradientDescent/update_Variable/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" attr { key: "T" value { @@ -3363,11 +3372,9 @@ meta_graphs { } } attr { - key: "_class" + key: "Tdim" value { - list { - s: "loc:@Variable" - } + type: DT_INT32 } } attr { @@ -3376,28 +3383,21 @@ meta_graphs { list { shape { dim { - size: 784 + size: -1 } dim { - size: 10 + size: 1 } } } } } - attr { - key: "use_locking" - value { - b: false - } - } } node { - name: "GradientDescent/update_Variable_1/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable_1" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" attr { key: "T" value { @@ -3405,73 +3405,87 @@ meta_graphs { } } attr { - key: "_class" + key: "_output_shapes" value { list { - s: "loc:@Variable_1" + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } } } } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } attr { key: "_output_shapes" value { list { shape { dim { - size: 10 + size: 2 } } } } } attr { - key: "use_locking" + key: "out_type" value { - b: false + type: DT_INT32 } } } node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_Variable/ApplyGradientDescent" - input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" - } - node { - name: "TopKV2/k" - op: "Const" + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" attr { - key: "_output_shapes" + key: "T" value { - list { - shape { - } - } + type: DT_FLOAT } } attr { - key: "dtype" + key: "Tshape" value { type: DT_INT32 } } attr { - key: "value" + key: "_output_shapes" value { - tensor { - dtype: DT_INT32 - tensor_shape { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } } - int_val: 10 } } } } node { - name: "TopKV2" - op: "TopKV2" - input: "y" - input: "TopKV2/k" + name: "gradients/add_grad/Shape" + op: "Shape" + input: "MatMul" attr { key: "T" value { @@ -3484,32 +3498,21 @@ meta_graphs { list { shape { dim { - size: -1 - } - dim { - size: 10 - } - } - shape { - dim { - size: -1 - } - dim { - size: 10 + size: 2 } } } } } attr { - key: "sorted" + key: "out_type" value { - b: true + type: DT_INT32 } } } node { - name: "Const_1" + name: "gradients/add_grad/Shape_1" op: "Const" attr { key: "_output_shapes" @@ -3517,7 +3520,7 @@ meta_graphs { list { shape { dim { - size: 10 + size: 1 } } } @@ -3526,133 +3529,207 @@ meta_graphs { attr { key: "dtype" value { - type: DT_STRING + type: DT_INT32 } } attr { key: "value" value { tensor { - dtype: DT_STRING + dtype: DT_INT32 tensor_shape { dim { - size: 10 + size: 1 } } - string_val: "0" - string_val: "1" - string_val: "2" - string_val: "3" - string_val: "4" - string_val: "5" - string_val: "6" - string_val: "7" - string_val: "8" - string_val: "9" + int_val: 10 } } } } node { - name: "index_to_string/Size" - op: "Const" + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } attr { key: "_output_shapes" value { list { shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } } } } } + } + node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs" attr { - key: "dtype" + key: "T" value { - type: DT_INT32 + type: DT_FLOAT } } attr { - key: "value" + key: "Tidx" value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } + type: DT_INT32 } } - } - node { - name: "index_to_string/range/start" - op: "Const" attr { key: "_output_shapes" value { list { shape { + unknown_rank: true } } } } attr { - key: "dtype" + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" value { type: DT_INT32 } } attr { - key: "value" + key: "_output_shapes" value { - tensor { - dtype: DT_INT32 - tensor_shape { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } } - int_val: 0 } } } } node { - name: "index_to_string/range/delta" - op: "Const" + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } attr { key: "_output_shapes" value { list { shape { + unknown_rank: true } } } } attr { - key: "dtype" + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" value { type: DT_INT32 } } attr { - key: "value" + key: "_output_shapes" value { - tensor { - dtype: DT_INT32 - tensor_shape { + list { + shape { + dim { + size: 10 + } } - int_val: 1 } } } } node { - name: "index_to_string/range" - op: "Range" - input: "index_to_string/range/start" - input: "index_to_string/Size" - input: "index_to_string/range/delta" + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + } + node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" attr { - key: "Tidx" + key: "T" value { - type: DT_INT32 + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } } } attr { @@ -3661,6 +3738,9 @@ meta_graphs { list { shape { dim { + size: -1 + } + dim { size: 10 } } @@ -3669,19 +3749,22 @@ meta_graphs { } } node { - name: "index_to_string/ToInt64" - op: "Cast" - input: "index_to_string/range" + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" attr { - key: "DstT" + key: "T" value { - type: DT_INT64 + type: DT_FLOAT } } attr { - key: "SrcT" + key: "_class" value { - type: DT_INT32 + list { + s: "loc:@gradients/add_grad/Reshape_1" + } } } attr { @@ -3698,111 +3781,207 @@ meta_graphs { } } node { - name: "index_to_string" - op: "HashTableV2" + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } attr { key: "_output_shapes" value { list { shape { + dim { + size: -1 + } + dim { + size: 784 + } } } } } attr { - key: "container" + key: "transpose_a" value { - s: "" + b: false } } attr { - key: "key_dtype" + key: "transpose_b" value { - type: DT_INT64 + b: true } } + } + node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/add_grad/tuple/control_dependency" attr { - key: "shared_name" + key: "T" value { - s: "" + type: DT_FLOAT } } attr { - key: "use_node_name_sharing" + key: "_output_shapes" value { - b: false + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } } } attr { - key: "value_dtype" + key: "transpose_a" value { - type: DT_STRING + b: true + } + } + attr { + key: "transpose_b" + value { + b: false } } } node { - name: "index_to_string/Const" - op: "Const" + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul" + } + } + } attr { key: "_output_shapes" value { list { shape { + dim { + size: -1 + } + dim { + size: 784 + } } } } } + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" attr { - key: "dtype" + key: "T" value { - type: DT_STRING + type: DT_FLOAT } } attr { - key: "value" + key: "_class" value { - tensor { - dtype: DT_STRING - tensor_shape { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } } - string_val: "UNK" } } } } node { - name: "index_to_string/table_init" - op: "InitializeTableV2" - input: "index_to_string" - input: "index_to_string/ToInt64" - input: "Const_1" + name: "GradientDescent/learning_rate" + op: "Const" attr { - key: "Tkey" + key: "_output_shapes" value { - type: DT_INT64 + list { + shape { + } + } } } attr { - key: "Tval" + key: "dtype" value { - type: DT_STRING + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } } } } node { - name: "ToInt64" - op: "Cast" - input: "TopKV2:1" + name: "GradientDescent/update_Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" attr { - key: "DstT" + key: "T" value { - type: DT_INT64 + type: DT_FLOAT } } attr { - key: "SrcT" + key: "_class" value { - type: DT_INT32 + list { + s: "loc:@Variable" + } } } attr { @@ -3811,7 +3990,7 @@ meta_graphs { list { shape { dim { - size: -1 + size: 784 } dim { size: 10 @@ -3820,23 +3999,31 @@ meta_graphs { } } } + attr { + key: "use_locking" + value { + b: false + } + } } node { - name: "index_to_string_Lookup" - op: "LookupTableFindV2" - input: "index_to_string" - input: "ToInt64" - input: "index_to_string/Const" + name: "GradientDescent/update_Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" attr { - key: "Tin" + key: "T" value { - type: DT_INT64 + type: DT_FLOAT } } attr { - key: "Tout" + key: "_class" value { - type: DT_STRING + list { + s: "loc:@Variable_1" + } } } attr { @@ -3845,15 +4032,30 @@ meta_graphs { list { shape { dim { - size: -1 - } - dim { size: 10 } } } } } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_Variable/ApplyGradientDescent" + input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^Variable/Assign" + input: "^Variable_1/Assign" } node { name: "ArgMax/dimension" @@ -3888,7 +4090,7 @@ meta_graphs { node { name: "ArgMax" op: "ArgMax" - input: "y" + input: "add" input: "ArgMax/dimension" attr { key: "T" @@ -3954,7 +4156,7 @@ meta_graphs { node { name: "ArgMax_1" op: "ArgMax" - input: "Placeholder" + input: "Placeholder_1" input: "ArgMax_1/dimension" attr { key: "T" @@ -4012,7 +4214,7 @@ meta_graphs { } } node { - name: "Cast" + name: "Cast_1" op: "Cast" input: "Equal" attr { @@ -4041,7 +4243,7 @@ meta_graphs { } } node { - name: "Const_2" + name: "Const_1" op: "Const" attr { key: "_output_shapes" @@ -4077,10 +4279,10 @@ meta_graphs { } } node { - name: "Mean" + name: "Mean_1" op: "Mean" - input: "Cast" - input: "Const_2" + input: "Cast_1" + input: "Const_1" attr { key: "T" value { @@ -4110,16 +4312,6 @@ meta_graphs { } } node { - name: "init_all_tables" - op: "NoOp" - input: "^index_to_string/table_init" - } - node { - name: "legacy_init_op" - op: "NoOp" - input: "^init_all_tables" - } - node { name: "save/Const" op: "Const" attr { @@ -4174,7 +4366,7 @@ meta_graphs { dtype: DT_STRING tensor_shape { } - string_val: "_temp_8390c48b96834292ab57050d0ae6959e/part" + string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part" } } } @@ -4783,22 +4975,6 @@ meta_graphs { version: V2 } collection_def { - key: "legacy_init_op" - value { - node_list { - value: "legacy_init_op" - } - } - } - collection_def { - key: "table_initializer" - value { - node_list { - value: "index_to_string/table_init" - } - } - } - collection_def { key: "train_op" value { node_list { @@ -4810,8 +4986,8 @@ meta_graphs { key: "trainable_variables" value { bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0" + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" } } } @@ -4819,18 +4995,18 @@ meta_graphs { key: "variables" value { bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0" + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" } } } signature_def { - key: "predict_images" + key: "serving_default" value { inputs { - key: "images" + key: "x" value { - name: "x:0" + name: "Placeholder:0" dtype: DT_FLOAT tensor_shape { dim { @@ -4843,9 +5019,9 @@ meta_graphs { } } outputs { - key: "scores" + key: "y" value { - name: "y:0" + name: "add:0" dtype: DT_FLOAT tensor_shape { dim { @@ -4860,50 +5036,4 @@ meta_graphs { method_name: "tensorflow/serving/predict" } } - signature_def { - key: "serving_default" - value { - inputs { - key: "inputs" - value { - name: "tf_example:0" - dtype: DT_STRING - tensor_shape { - unknown_rank: true - } - } - } - outputs { - key: "classes" - value { - name: "index_to_string_Lookup:0" - dtype: DT_STRING - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - outputs { - key: "scores" - value { - name: "TopKV2:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - method_name: "tensorflow/serving/classify" - } - } } diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..8474aa0a04c --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..cfcdac20409 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 Binary files differdeleted file mode 100644 index ba71c21fbe1..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 +++ /dev/null diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index Binary files differdeleted file mode 100644 index 84f4593515a..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index +++ /dev/null diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 82e5d0cfe5b..3aa2d144f1f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.tensor.Tensor; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; import org.junit.Test; + import static org.junit.Assert.assertEquals; /** @@ -83,7 +88,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0, "sin(0)"); tester.assertEvaluates(1, "cos(0)"); tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))"); - + // Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero) tester.assertEvaluates(0, "random(1)"); tester.assertEvaluates(0, "random(foo)"); @@ -152,7 +157,7 @@ public class EvaluationTestCase { "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); - + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index ee2b1c147e3..ba0db4de5e1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -34,7 +34,7 @@ public class EvaluationTester { } // TODO: Test both bound and unbound indexed - public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, + public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, String ... tensorArgumentStrings) { MapContext context = defaultContext.thawedCopy(); int argumentIndex = 0; @@ -46,7 +46,7 @@ public class EvaluationTester { argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString); context.put("tensor" + (argumentIndex++), new TensorValue(argument)); } - return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, + return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, mappedTensors ? "Mapped tensors" : "Indexed tensors"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java new file mode 100644 index 00000000000..dab42801d70 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -0,0 +1,114 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Mnist_SoftmaxTestCase { + + @Test + public void testImporting() { + String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; + ImportResult result = new TensorFlowImporter().importModel(modelDir); + + // Check logged messages + result.warnings().forEach(System.err::println); + assertEquals(0, result.warnings().size()); + + // Check arguments + assertEquals(1, result.arguments().size()); + TensorType argument0 = result.arguments().get("Placeholder"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // Check constants + assertEquals(2, result.constants().size()); + + Tensor constant0 = result.constants().get("Variable"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = result.constants().get("Variable_1"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check resulting Vespa expression + assertEquals(1, result.expressions().size()); + assertEquals("y", result.expressions().get(0).getName()); + assertEquals("" + + "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + + "rename(constant(Variable_1), d0, d1), " + + "f(a,b)(a + b))", + toNonPrimitiveString(result.expressions().get(0))); + + // Test execution + String signatureName = "serving_default"; + + assertEqualResult(modelDir, signatureName, "Variable/read"); + assertEqualResult(modelDir, signatureName, "Variable_1/read"); + // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder"); + assertEqualResult(modelDir, signatureName, "MatMul"); + assertEqualResult(modelDir, signatureName, "add"); + } + + private void assertEqualResult(String modelDir, String signatureName, String operationName) { + ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName); + + Tensor tfResult = tensorFlowExecute(modelDir, operationName); + Context context = contextFrom(result); + Tensor placeholder = placeholderArgument(); + context.put("Placeholder", new TensorValue(placeholder)); + Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); + } + + private Tensor tensorFlowExecute(String modelDir, String operationName) { + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); + runner.feed("Placeholder", placeholder); + List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(ImportResult result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private String toNonPrimitiveString(RankingExpression expression) { + // toString on the wrapping expression will map to primitives, which is harder to read + return ((TensorFunctionNode)expression.getRoot()).function().toString(); + } + + private Tensor placeholderArgument() { + int size = 784; + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); + for (int i = 0; i < size; i++) + b.cell(0, 0, i); + return b.build(); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java deleted file mode 100644 index aaf198a9e8f..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ /dev/null @@ -1,79 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; -import java.util.logging.Level; -import java.util.stream.Collectors; - -import static org.junit.Assert.assertEquals; - -/** - * @author bratseth - */ -public class TensorFlowImporterTestCase { - - @Test - public void testModel1() { - List<NamedTensor> constants = new ArrayList<>(); - TestLogger logger = new TestLogger(); - List<RankingExpression> expressions = - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger); - - // Check constants - assertEquals(2, constants.size()); - - assertEquals("Variable", constants.get(0).name()); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), - constants.get(0).tensor().type()); - assertEquals(7840, constants.get(0).tensor().size()); - - assertEquals("Variable_1", constants.get(1).name()); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), - constants.get(1).tensor().type()); - assertEquals(10, constants.get(1).tensor().size()); - - // Check logged messages - assertEquals(2, logger.messages().size()); - assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported", - logger.messages().get(0)); - assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported", - logger.messages().get(1)); - - // Check resulting Vespa expression - assertEquals(1, expressions.size()); - assertEquals("scores", expressions.get(0).getName()); - assertEquals("" + - "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " + - "constant(Variable_1), " + - "f(a,b)(a + b)), " + - "d0)", - toNonPrimitiveString(expressions.get(0))); - } - - private String toNonPrimitiveString(RankingExpression expression) { - // toString on the wrapping expression will map to primitives, which is harder to read - return ((TensorFunctionNode)expression.getRoot()).function().toString(); - } - - private class TestLogger implements TensorFlowImporter.MessageLogger { - - private List<String> messages = new ArrayList<>(); - - /** Returns the messages in sorted order */ - public List<String> messages() { - return messages.stream().sorted().collect(Collectors.toList()); - } - - @Override - public void log(Level level, String message) { - messages.add(message); - } - - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index dde9d4bf21e..1960c1fe876 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -59,7 +59,7 @@ public class TensorConformanceTest { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); - + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); return true; @@ -67,7 +67,7 @@ public class TensorConformanceTest { if (!node.has("expression")) { return true; // ignore } - + String expression = node.get("expression").asText(); MapContext context = getInput(node.get("inputs")); Tensor expect = getTensor(node.get("result").get("expect").asText()); |