summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2017-12-18 09:14:37 +0100
committerGitHub <noreply@github.com>2017-12-18 09:14:37 +0100
commit9347da6b81bd1f723d754fea2add617268ea90fa (patch)
tree6b6489e089b2ff9c5d67599d6be55d694c9ee99b /searchlib/src
parentdbf5328bdc8daed3e4111742e2f6e0a48277e3d3 (diff)
Revert "Revert "Bratseth/tensorflow models""
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java44
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java51
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java160
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java94
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java147
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java36
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py89
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt5039
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001bin0 -> 31400 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.indexbin0 -> 159 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java114
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java4
21 files changed, 5817 insertions, 59 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 785ed78492e..0eeb0a9e630 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
*
- * @param name The name of the variable whose value to return.
- * @return The value of the named variable.
+ * @param name the name of the variable whose value to return.
+ * @return the value of the named variable.
*/
public abstract Value get(String name);
+ /** Returns a variable as a tensor */
+ @Override
+ public Tensor getTensor(String name) { return get(name).asTensor(); }
+
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
* string. This may be used to implement more advanced variables whose
* values are calculated at runtime from arguments. Supporting this in a
- * context is optional.
- *
+ * context is optional.
+ *
* <p>This default implementation generates a name on the form
* <code>name(argument1, argument2, ...argumentN).output</code>.
* If there are no arguments the parenthesis are omitted.
* If there is no output, the dot is omitted.</p>
*
- * @param name The name of this variable.
- * @param arguments The parsed arguments as given in the textual expression.
- * @param output The name of the value to output (to enable one named
+ * @param name the name of this variable.
+ * @param arguments the parsed arguments as given in the textual expression.
+ * @param output the name of the value to output (to enable one named
* calculation to output several), or null to output the
* "main" (or only) value.
*/
@@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext {
* context subclasses. This default implementation throws
* UnsupportedOperationException.</p>
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public Value get(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
}
/**
- * <p>Lookup by index rather than name directly to a double. This is supported by some optimized
+ * Lookup by index rather than name directly to a double. This is supported by some optimized
* context subclasses. This default implementation throws
- * UnsupportedOperationException.</p>
+ * UnsupportedOperationException.
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public double getDouble(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
@@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext {
}
/**
- * <p>Sets a value to this, or throws an UnsupportedOperationException if
- * this is not supported. This default implementation does the latter.</p> *
+ * Sets a value to this, or throws an UnsupportedOperationException if
+ * this is not supported. This default implementation does the latter.
*
- * @param name The name of the variable to set.
+ * @param name the name of the variable to set.
* @param value the value to set. Ownership of this value is transferred to this - if it is mutable
* (not frozen) it may be modified during execution
- * @since 5.1.5
*/
public void put(String name, Value value) {
throw new UnsupportedOperationException(this + " does not support variable assignment");
}
/**
- * <p>Returns all the names available in this, or throws an
+ * Returns all the names available in this, or throws an
* UnsupportedOperationException if this operation is not supported. This
- * default implementation does the latter.</p>
+ * default implementation does the latter.
*
- * @return The set of all variable names.
+ * @return the set of all variable names.
*/
public Set<String> names() {
throw new UnsupportedOperationException(this + " does not support return a list of its names");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index ea750295423..2ef4a2ede2f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
@@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value {
public boolean hasDouble() { return true; }
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index ac8aba6a617..dad69b31181 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A string value.
*
* @author bratseth
- * @since 5.1.21
*/
public class StringValue extends Value {
@@ -35,6 +37,11 @@ public class StringValue extends Value {
}
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 49c3ccb7b01..26c30fe5ed2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -2,14 +2,10 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.Optional;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
/**
* A Value containing a tensor.
@@ -23,7 +19,7 @@ public class TensorValue extends Value {
/** The tensor value of this */
private final Tensor value;
-
+
public TensorValue(Tensor value) {
this.value = value;
}
@@ -131,7 +127,7 @@ public class TensorValue extends Value {
public Value compare(TruthOperator operator, Value argument) {
return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
}
-
+
private Tensor compareTensor(TruthOperator operator, Tensor argument) {
switch (operator) {
case LARGER: return value.larger(argument);
@@ -152,7 +148,7 @@ public class TensorValue extends Value {
else
return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
}
-
+
private Tensor functionOnTensor(Function function, Tensor argument) {
switch (function) {
case min: return value.min(argument);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index b2ccbe572d0..40d70e0022c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* The result of a ranking expression evaluation.
@@ -25,6 +27,14 @@ public abstract class Value {
return new DoubleValue(asDouble());
}
+ /** Returns this as a tensor value */
+ public abstract Tensor asTensor();
+
+ /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ protected Tensor doubleAsTensor(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
+ }
+
/** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */
public abstract boolean hasDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
new file mode 100644
index 00000000000..b4a9b363ade
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -0,0 +1,51 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * The result of importing a TensorFlow model into Vespa:
+ * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
+ * - A list of named constant tensors
+ * - A list of expected input tensors, with their tensor type
+ * - A list of warning messages
+ *
+ * @author bratseth
+ */
+// This object can be built incrementally within this package, but is immutable when observed from outside the package
+// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
+public class ImportResult {
+
+ private final List<RankingExpression> expressions = new ArrayList<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final List<String> warnings = new ArrayList<>();
+
+ void add(RankingExpression expression) { expressions.add(expression); }
+ void set(String name, Tensor constant) { constants.put(name, constant); }
+ void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void warn(String warning) { warnings.add(warning); }
+
+ /** Returns an immutable list of the expressions of this */
+ public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+
+ /** Returns an immutable map of the constants of this */
+ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+
+ /** Returns an immutable map of the arguments of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+ /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
+ public List<String> warnings() {
+ return warnings.stream().sorted().collect(Collectors.toList());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
new file mode 100644
index 00000000000..e7f7b5ef2f4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -0,0 +1,160 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
+ *
+ * @author bratseth
+ */
+class OperationMapper {
+
+ /*
+ A note on conversion from implicitly numbered to explicitly named dimensions:
+ Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+ 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+ comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+ around dimension renaming operations which mirrors those built into the TF operation definitions.
+
+ To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+ dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+ and the result is then renamed again (if necessary) to recover this convention across a full nested
+ computation.
+
+ This requires us to track tensor types throughout the conversion.
+ */
+
+ private TensorConverter tensorConverter = new TensorConverter();
+
+ TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+ ensureArguments(2, arguments, "join");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < b.type().rank())
+ throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " +
+ "but this is not supported when the second argument has a higher rank");
+
+ TensorFunction bFunction = b.function();
+
+ if (a.type().rank() > b.type().rank()) {
+ // Well now we have entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensor do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+ List<String> renameFrom = new ArrayList<>();
+ List<String> renameTo = new ArrayList<>();
+ int sizeDifference = a.type().rank() - b.type().rank();
+ for (int i = 0; i < b.type().rank(); i++) {
+ renameFrom.add(b.type().dimensions().get(i).name());
+ renameTo.add("d" + (sizeDifference + i));
+ }
+ bFunction = new Rename(bFunction, renameFrom, renameTo);
+ }
+
+ Join function = new Join(a.function(), bFunction, doubleFunction);
+ return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank
+ }
+
+ TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+ ensureArguments(1, arguments, "apply");
+ TypedTensorFunction a = arguments.get(0);
+
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ String name = tfNode.getName();
+ TensorType type = result.arguments().get(name);
+ if (type == null)
+ throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name +
+ "', but there is no such input");
+ // Included literally in the expression and so must be produced by a separate macro in the rank profile
+ return new TypedTensorFunction(type, new VariableTensor(name));
+ }
+
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ if ( ! tfNode.getName().endsWith("/read"))
+ throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
+ "nodes are only supported when reading variables");
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+
+ String name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+ Session.Runner fetched = model.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if ( importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
+ importedTensors.size());
+ Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
+ result.set(name, constant);
+ return new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
+ }
+
+ TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
+ ensureArguments(2, arguments, "matmul");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < 2 || b.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (a.type().rank() != b.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ String afterLastDim = "d" + (a.type().rank() + 1);
+ // Let the first dimension of the second tensor be the same as the second dimension of the first
+ // and the second dimension of the second argument be not present in the first argument, while leaving the
+ // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
+
+ // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
+
+ Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
+ ImmutableList.of("d1", afterLastDim));
+ Matmul matmul = new Matmul(a.function(), renamedB, "d1");
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
+ new Rename(matmul, afterLastDim, "d1"));
+ }
+
+ TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
+ ensureArguments(1, arguments, "softmax");
+ TypedTensorFunction a = arguments.get(0);
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
+ Softmax softmax = new Softmax(a.function(), dimension);
+ return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ }
+
+ private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
+ if ( arguments.size() != count)
+ throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
+ ", but got " + arguments.size());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
new file mode 100644
index 00000000000..df43225c333
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -0,0 +1,94 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+
+/**
+ * @author bratseth
+ */
+public class TensorConverter {
+
+ public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ TensorType type = toVespaTensorType(tfTensor.shape());
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ private TensorType toVespaTensorType(long[] shape) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed("d" + (dimensionIndex++), (int) dimensionSize);
+ }
+ return b.build();
+ }
+
+ private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ // TODO: The rest
+ default:
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+
+ private final int size;
+
+ protected Values(int size) {
+ this.size = size;
+ }
+
+ abstract double get(int i);
+
+ int size() { return size; }
+
+ }
+
+ private static class DoubleValues extends Values {
+
+ private final DoubleBuffer values;
+
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+ private static class FloatValues extends Values {
+
+ private final FloatBuffer values;
+
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
new file mode 100644
index 00000000000..33523244129
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -0,0 +1,147 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ */
+public class TensorFlowImporter {
+
+ private final OperationMapper operationMapper = new OperationMapper();
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a pbtxt file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportResult importModel(String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
+ SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
+ ImportResult result = new ImportResult();
+ importInputs(signature.getInputsMap(), result);
+ result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
+ return result;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ ImportResult result = new ImportResult();
+ for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
+ importInputs(signatureEntry.getValue().getInputsMap(), result);
+ for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ try {
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
+ result.add(new RankingExpression(output.getKey(), node));
+ }
+ catch (IllegalArgumentException e) {
+ result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
+ signatureEntry.getValue().getMethodName() +
+ "': " + Exceptions.toMessageString(e));
+ }
+ }
+ }
+ return result;
+ }
+
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
+ inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
+ importTensorType(value.getTensorShape())));
+ }
+
+ private TensorType importTensorType(TensorShapeProto tensorShape) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
+ int dimensionSize = (int)dimension.getSize();
+ if (dimensionSize >= 0)
+ b.indexed("d" + b.rank(), dimensionSize);
+ else
+ b.indexed("d" + b.rank()); // unbound size
+ }
+ return b.build();
+ }
+
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return importNode(nameOf(output.getName()), graph, model, result);
+ }
+
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
+ return new TensorFunctionNode(function); // wrap top level (only) as an expression
+ }
+
+ /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tensorFunctionOf(tfNode, graph, model, result);
+ }
+
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
+ // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
+ switch (tfNode.getOp().toLowerCase()) {
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, result);
+ case "identity" : return operationMapper.identity(tfNode, model, result);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
+ default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
+ }
+ }
+
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tfNode.getInputList().stream()
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
+ .collect(Collectors.toList());
+ }
+
+ private NodeDef getNode(String name, GraphDef graph) {
+ return graph.getNodeList().stream()
+ .filter(node -> node.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private String nameOf(String name) {
+ return name.split(":")[0];
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
new file mode 100644
index 00000000000..5712da77700
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+/**
+ * A tensor function returning a specific tensor type
+ *
+ * @author bratseth
+ */
+final class TypedTensorFunction {
+
+ private final TensorType type;
+ private final TensorFunction function;
+
+ public TypedTensorFunction(TensorType type, TensorFunction function) {
+ this.type = type;
+ this.function = function;
+ }
+
+ public TensorType type() { return type; }
+ public TensorFunction function() { return function; }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 71699b379b2..d366c9bfbe5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -14,23 +14,23 @@ import java.util.function.*;
/**
* A tensor generating function, whose arguments are determined by a tensor type
- *
+ *
* @author bratseth
*/
public class GeneratorLambdaFunctionNode extends CompositeNode {
private final TensorType type;
private final ExpressionNode generator;
-
+
public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
- throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
+ throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
"dimensions, but tried to generate " + type);
// TODO: Verify that the function only accesses the given arguments
this.type = type;
this.generator = generator;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(generator);
@@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
public Value evaluate(Context context) {
return generator.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as an operator which converts a list of integers into a double
*/
public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
@@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
context.put(type.dimensions().get(i).name(), arguments.get(i));
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return GeneratorLambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 1f8db6e036c..ba765d07094 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -17,7 +17,7 @@ import java.util.Map;
* @author bratseth
*/
public class SerializationContext {
-
+
/** Expression functions indexed by name */
private final ImmutableMap<String, ExpressionFunction> functions;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ce21e132980..8af3448ca6f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -21,22 +21,32 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorFunctionNode extends CompositeNode {
private final TensorFunction function;
-
+
public TensorFunctionNode(TensorFunction function) {
this.function = function;
}
+ /** Returns the tensor function wrapped by this */
+ public TensorFunction function() { return function; }
+
@Override
public List<ExpressionNode> children() {
return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .map(this::toExpressionNode)
.collect(Collectors.toList());
}
+ private ExpressionNode toExpressionNode(TensorFunction f) {
+ if (f instanceof TensorFunctionExpressionNode)
+ return ((TensorFunctionExpressionNode)f).expression;
+ else
+ return new TensorFunctionNode(f);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
@@ -50,7 +60,7 @@ public class TensorFunctionNode extends CompositeNode {
// Serialize as primitive
return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
}
-
+
@Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
@@ -59,8 +69,8 @@ public class TensorFunctionNode extends CompositeNode {
public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
-
- /**
+
+ /**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
@@ -68,13 +78,13 @@ public class TensorFunctionNode extends CompositeNode {
/** An expression which produces a tensor */
private final ExpressionNode expression;
-
+
public TensorFunctionExpressionNode(ExpressionNode expression) {
this.expression = expression;
}
-
+
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> functionArguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -108,7 +118,7 @@ public class TensorFunctionNode extends CompositeNode {
public String toString() {
return toString(ExpressionNodeToStringContext.empty);
}
-
+
@Override
public String toString(ToStringContext c) {
if (c instanceof ExpressionNodeToStringContext) {
@@ -121,14 +131,14 @@ public class TensorFunctionNode extends CompositeNode {
}
}
-
+
/** Allows passing serialization context arguments through TensorFunctions */
private static class ExpressionNodeToStringContext implements ToStringContext {
-
+
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
-
+
public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
new file mode 100644
index 00000000000..a1861a1c981
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A very simple MNIST classifier.
+
+See extensive documentation at
+https://www.tensorflow.org/get_started/mnist/beginners
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+import tensorflow as tf
+
+FLAGS = None
+
+
+def main(_):
+ # Import data
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+
+ # Create the model
+ x = tf.placeholder(tf.float32, [None, 784])
+ W = tf.Variable(tf.zeros([784, 10]))
+ b = tf.Variable(tf.zeros([10]))
+ y = tf.matmul(x, W) + b
+
+ # Define loss and optimizer
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ # The raw formulation of cross-entropy,
+ #
+ # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
+ # reduction_indices=[1]))
+ #
+ # can be numerically unstable.
+ #
+ # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+ # outputs of 'y', and then average across the batch.
+ cross_entropy = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+ # Train
+ for _ in range(1000):
+ batch_xs, batch_ys = mnist.train.next_batch(100)
+ sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
+
+ # Test trained model
+ correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print(sess.run(accuracy, feed_dict={x: mnist.test.images,
+ y_: mnist.test.labels}))
+
+ # Save the model
+ export_path = "saved"
+ print('Exporting trained model to ', export_path)
+ builder = tf.saved_model.builder.SavedModelBuilder(export_path)
+ signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y})
+ builder.add_meta_graph_and_variables(sess,
+ [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={'serving_default':signature})
+ builder.save(as_text=True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
+ help='Directory for storing input data')
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
new file mode 100644
index 00000000000..8100dfd594d
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
@@ -0,0 +1,5039 @@
+saved_model_schema_version: 1
+meta_graphs {
+ meta_info_def {
+ stripped_op_list {
+ op {
+ name: "Add"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+ }
+ op {
+ name: "ApplyGradientDescent"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "delta"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ }
+ op {
+ name: "ArgMax"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dimension"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "output_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "output_type"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Assign"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "validate_shape"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ allows_uninitialized_input: true
+ }
+ op {
+ name: "BroadcastGradientArgs"
+ input_arg {
+ name: "s0"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "s1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r0"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r1"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Cast"
+ input_arg {
+ name: "x"
+ type_attr: "SrcT"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "DstT"
+ }
+ attr {
+ name: "SrcT"
+ type: "type"
+ }
+ attr {
+ name: "DstT"
+ type: "type"
+ }
+ }
+ op {
+ name: "ConcatV2"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Const"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ }
+ op {
+ name: "Equal"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_QUINT8
+ type: DT_QINT8
+ type: DT_QINT32
+ type: DT_STRING
+ type: DT_BOOL
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "ExpandDims"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dim"
+ type_attr: "Tdim"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tdim"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Fill"
+ input_arg {
+ name: "dims"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "FloorDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Identity"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "MatMul"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ type_attr: "T"
+ }
+ attr {
+ name: "transpose_a"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "transpose_b"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "Mean"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "MergeV2Checkpoints"
+ input_arg {
+ name: "checkpoint_prefixes"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "destination_prefix"
+ type: DT_STRING
+ }
+ attr {
+ name: "delete_old_dirs"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Mul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "NoOp"
+ }
+ op {
+ name: "Pack"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "axis"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ }
+ op {
+ name: "Placeholder"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ op {
+ name: "Prod"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RealDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Reshape"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "shape"
+ type_attr: "Tshape"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tshape"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RestoreV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "SaveV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Shape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "ShardedFilename"
+ input_arg {
+ name: "basename"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shard"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "num_shards"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "filename"
+ type: DT_STRING
+ }
+ }
+ op {
+ name: "Slice"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "begin"
+ type_attr: "Index"
+ }
+ input_arg {
+ name: "size"
+ type_attr: "Index"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Index"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "SoftmaxCrossEntropyWithLogits"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "labels"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "loss"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprop"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ }
+ op {
+ name: "StringJoin"
+ input_arg {
+ name: "inputs"
+ type: DT_STRING
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "separator"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ }
+ op {
+ name: "Sub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sum"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Tile"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "multiples"
+ type_attr: "Tmultiples"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tmultiples"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "VariableV2"
+ output_arg {
+ name: "ref"
+ type_attr: "dtype"
+ is_ref: true
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "ZerosLike"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ }
+ tags: "serve"
+ tensorflow_version: "1.4.1"
+ tensorflow_git_version: "v1.4.0-19-ga52c8d9b01"
+ }
+ graph_def {
+ node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "zeros"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable/read"
+ op: "Identity"
+ input: "Variable"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable_1"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable_1/Assign"
+ op: "Assign"
+ input: "Variable_1"
+ input: "zeros_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable_1/read"
+ op: "Identity"
+ input: "Variable_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "add"
+ op: "Add"
+ input: "MatMul"
+ input: "Variable_1/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Placeholder_1"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Rank_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_1"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub"
+ op: "Sub"
+ input: "Rank_1"
+ input: "Sub/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice/begin"
+ op: "Pack"
+ input: "Sub"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice"
+ op: "Slice"
+ input: "Shape_1"
+ input: "Slice/begin"
+ input: "Slice/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "concat/values_0"
+ input: "Slice"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape"
+ op: "Reshape"
+ input: "add"
+ input: "concat"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_2"
+ op: "Shape"
+ input: "Placeholder_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub_1/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_1"
+ op: "Sub"
+ input: "Rank_2"
+ input: "Sub_1/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1/begin"
+ op: "Pack"
+ input: "Sub_1"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_1/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1"
+ op: "Slice"
+ input: "Shape_2"
+ input: "Slice_1/begin"
+ input: "Slice_1/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "concat_1/values_0"
+ input: "Slice_1"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_1"
+ op: "Reshape"
+ input: "Placeholder_1"
+ input: "concat_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "SoftmaxCrossEntropyWithLogits"
+ op: "SoftmaxCrossEntropyWithLogits"
+ input: "Reshape"
+ input: "Reshape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2"
+ op: "Sub"
+ input: "Rank"
+ input: "Sub_2/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/begin"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/size"
+ op: "Pack"
+ input: "Sub_2"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_2"
+ op: "Slice"
+ input: "Shape"
+ input: "Slice_2/begin"
+ input: "Slice_2/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_2"
+ op: "Reshape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ input: "Slice_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean"
+ op: "Mean"
+ input: "Reshape_2"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Fill"
+ op: "Fill"
+ input: "gradients/Shape"
+ input: "gradients/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape/shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Fill"
+ input: "gradients/Mean_grad/Reshape/shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Tile"
+ op: "Tile"
+ input: "gradients/Mean_grad/Reshape"
+ input: "gradients/Mean_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tmultiples"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_1"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_1"
+ input: "gradients/Mean_grad/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const_1"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod_1"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_2"
+ input: "gradients/Mean_grad/Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum/y"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum"
+ op: "Maximum"
+ input: "gradients/Mean_grad/Prod_1"
+ input: "gradients/Mean_grad/Maximum/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/floordiv"
+ op: "FloorDiv"
+ input: "gradients/Mean_grad/Prod"
+ input: "gradients/Mean_grad/Maximum"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Cast"
+ op: "Cast"
+ input: "gradients/Mean_grad/floordiv"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/truediv"
+ op: "RealDiv"
+ input: "gradients/Mean_grad/Tile"
+ input: "gradients/Mean_grad/Cast"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Shape"
+ op: "Shape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Mean_grad/truediv"
+ input: "gradients/Reshape_2_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/zeros_like"
+ op: "ZerosLike"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ op: "ExpandDims"
+ input: "gradients/Reshape_2_grad/Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ op: "Mul"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ input: "gradients/Reshape_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape"
+ op: "Shape"
+ input: "MatMul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 10
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/BroadcastGradientArgs"
+ op: "BroadcastGradientArgs"
+ input: "gradients/add_grad/Shape"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum"
+ input: "gradients/add_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum_1"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape_1"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum_1"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/Reshape_1"
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape_1"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul"
+ op: "MatMul"
+ input: "gradients/add_grad/tuple/control_dependency"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul_1"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "gradients/add_grad/tuple/control_dependency"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/MatMul_1"
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul_1"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/learning_rate"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.5
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/MatMul_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable_1/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable_1"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/add_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent"
+ op: "NoOp"
+ input: "^GradientDescent/update_Variable/ApplyGradientDescent"
+ input: "^GradientDescent/update_Variable_1/ApplyGradientDescent"
+ }
+ node {
+ name: "init"
+ op: "NoOp"
+ input: "^Variable/Assign"
+ input: "^Variable_1/Assign"
+ }
+ node {
+ name: "ArgMax/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax"
+ op: "ArgMax"
+ input: "add"
+ input: "ArgMax/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1"
+ op: "ArgMax"
+ input: "Placeholder_1"
+ input: "ArgMax_1/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "Equal"
+ op: "Equal"
+ input: "ArgMax"
+ input: "ArgMax_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Cast_1"
+ op: "Cast"
+ input: "Equal"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean_1"
+ op: "Mean"
+ input: "Cast_1"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "save/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "model"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin/inputs_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin"
+ op: "StringJoin"
+ input: "save/Const"
+ input: "save/StringJoin/inputs_1"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "separator"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "save/num_shards"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename/shard"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename"
+ op: "ShardedFilename"
+ input: "save/StringJoin"
+ input: "save/ShardedFilename/shard"
+ input: "save/num_shards"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: "Variable"
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: ""
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2"
+ op: "SaveV2"
+ input: "save/ShardedFilename"
+ input: "save/SaveV2/tensor_names"
+ input: "save/SaveV2/shape_and_slices"
+ input: "Variable"
+ input: "Variable_1"
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/control_dependency"
+ op: "Identity"
+ input: "save/ShardedFilename"
+ input: "^save/SaveV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@save/ShardedFilename"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ op: "Pack"
+ input: "save/ShardedFilename"
+ input: "^save/control_dependency"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints"
+ op: "MergeV2Checkpoints"
+ input: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ input: "save/Const"
+ attr {
+ key: "delete_old_dirs"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/Identity"
+ op: "Identity"
+ input: "save/Const"
+ input: "^save/control_dependency"
+ input: "^save/MergeV2Checkpoints"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2/tensor_names"
+ input: "save/RestoreV2/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "save/RestoreV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2_1/tensor_names"
+ input: "save/RestoreV2_1/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign_1"
+ op: "Assign"
+ input: "Variable_1"
+ input: "save/RestoreV2_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/restore_shard"
+ op: "NoOp"
+ input: "^save/Assign"
+ input: "^save/Assign_1"
+ }
+ node {
+ name: "save/restore_all"
+ op: "NoOp"
+ input: "^save/restore_shard"
+ }
+ versions {
+ producer: 24
+ }
+ }
+ saver_def {
+ filename_tensor_name: "save/Const:0"
+ save_tensor_name: "save/Identity:0"
+ restore_op_name: "save/restore_all"
+ max_to_keep: 5
+ sharded: true
+ keep_checkpoint_every_n_hours: 10000.0
+ version: V2
+ }
+ collection_def {
+ key: "train_op"
+ value {
+ node_list {
+ value: "GradientDescent"
+ }
+ }
+ }
+ collection_def {
+ key: "trainable_variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ collection_def {
+ key: "variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ signature_def {
+ key: "serving_default"
+ value {
+ inputs {
+ key: "x"
+ value {
+ name: "Placeholder:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ outputs {
+ key: "y"
+ value {
+ name: "add:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ method_name: "tensorflow/serving/predict"
+ }
+ }
+}
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
new file mode 100644
index 00000000000..8474aa0a04c
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
new file mode 100644
index 00000000000..cfcdac20409
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 82e5d0cfe5b..3aa2d144f1f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.*;
-import com.yahoo.tensor.Tensor;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import org.junit.Test;
+
import static org.junit.Assert.assertEquals;
/**
@@ -83,7 +88,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0, "sin(0)");
tester.assertEvaluates(1, "cos(0)");
tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))");
-
+
// Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero)
tester.assertEvaluates(0, "random(1)");
tester.assertEvaluates(0, "random(foo)");
@@ -152,7 +157,7 @@ public class EvaluationTestCase {
"tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
-
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ee2b1c147e3..ba0db4de5e1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -34,7 +34,7 @@ public class EvaluationTester {
}
// TODO: Test both bound and unbound indexed
- public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
+ public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
String ... tensorArgumentStrings) {
MapContext context = defaultContext.thawedCopy();
int argumentIndex = 0;
@@ -46,7 +46,7 @@ public class EvaluationTester {
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
context.put("tensor" + (argumentIndex++), new TensorValue(argument));
}
- return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
+ return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
mappedTensors ? "Mapped tensors" : "Indexed tensors");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
new file mode 100644
index 00000000000..863479f3531
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -0,0 +1,114 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Mnist_SoftmaxTestCase {
+
+ @Test
+ public void testImporting() {
+ String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ ImportResult result = new TensorFlowImporter().importModel(modelDir);
+
+ // Check logged messages
+ result.warnings().forEach(System.err::println);
+ assertEquals(0, result.warnings().size());
+
+ // Check arguments
+ assertEquals(1, result.arguments().size());
+ TensorType argument0 = result.arguments().get("Placeholder");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // Check constants
+ assertEquals(2, result.constants().size());
+
+ Tensor constant0 = result.constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = result.constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check resulting Vespa expression
+ assertEquals(1, result.expressions().size());
+ assertEquals("y", result.expressions().get(0).getName());
+ assertEquals("" +
+ "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
+ "rename(constant(Variable_1), d0, d1), " +
+ "f(a,b)(a + b))",
+ toNonPrimitiveString(result.expressions().get(0)));
+
+ // Test execution
+ String signatureName = "serving_default";
+
+ assertEqualResult(modelDir, signatureName, "Variable/read");
+ assertEqualResult(modelDir, signatureName, "Variable_1/read");
+ // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
+ assertEqualResult(modelDir, signatureName, "MatMul");
+ assertEqualResult(modelDir, signatureName, "add");
+ }
+
+ private void assertEqualResult(String modelDir, String signatureName, String operationName) {
+ ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
+
+ Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+ Context context = contextFrom(result);
+ Tensor placeholder = placeholderArgument();
+ context.put("Placeholder", new TensorValue(placeholder));
+ Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ }
+
+ private Tensor tensorFlowExecute(String modelDir, String operationName) {
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+ runner.feed("Placeholder", placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(ImportResult result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private String toNonPrimitiveString(RankingExpression expression) {
+ // toString on the wrapping expression will map to primitives, which is harder to read
+ return ((TensorFunctionNode)expression.getRoot()).function().toString();
+ }
+
+ private Tensor placeholderArgument() {
+ int size = 784;
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
+ for (int i = 0; i < size; i++)
+ b.cell(0, 0, i);
+ return b.build();
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index dde9d4bf21e..1960c1fe876 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -59,7 +59,7 @@ public class TensorConformanceTest {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
-
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
return true;
@@ -67,7 +67,7 @@ public class TensorConformanceTest {
if (!node.has("expression")) {
return true; // ignore
}
-
+
String expression = node.get("expression").asText();
MapContext context = getInput(node.get("inputs"));
Tensor expect = getTensor(node.get("result").get("expect").asText());