aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
commit3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 (patch)
treeec003528946a37b9f0aeb49e1b314fdc6601c26e /searchlib/src
parent5b67e6f8f641141f848ad3989156151f9f182441 (diff)
Check agreement between TF and Vespa execution
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java44
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java51
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java115
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java114
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java9
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py89
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt)2484
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001bin0 -> 31400 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.indexbin0 -> 159 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001bin31400 -> 0 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/model1/variables/variables.indexbin159 -> 0 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java114
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java79
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java4
24 files changed, 1767 insertions, 1441 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 785ed78492e..0eeb0a9e630 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
*
- * @param name The name of the variable whose value to return.
- * @return The value of the named variable.
+ * @param name the name of the variable whose value to return.
+ * @return the value of the named variable.
*/
public abstract Value get(String name);
+ /** Returns a variable as a tensor */
+ @Override
+ public Tensor getTensor(String name) { return get(name).asTensor(); }
+
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
* string. This may be used to implement more advanced variables whose
* values are calculated at runtime from arguments. Supporting this in a
- * context is optional.
- *
+ * context is optional.
+ *
* <p>This default implementation generates a name on the form
* <code>name(argument1, argument2, ...argumentN).output</code>.
* If there are no arguments the parenthesis are omitted.
* If there is no output, the dot is omitted.</p>
*
- * @param name The name of this variable.
- * @param arguments The parsed arguments as given in the textual expression.
- * @param output The name of the value to output (to enable one named
+ * @param name the name of this variable.
+ * @param arguments the parsed arguments as given in the textual expression.
+ * @param output the name of the value to output (to enable one named
* calculation to output several), or null to output the
* "main" (or only) value.
*/
@@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext {
* context subclasses. This default implementation throws
* UnsupportedOperationException.</p>
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public Value get(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
}
/**
- * <p>Lookup by index rather than name directly to a double. This is supported by some optimized
+ * Lookup by index rather than name directly to a double. This is supported by some optimized
* context subclasses. This default implementation throws
- * UnsupportedOperationException.</p>
+ * UnsupportedOperationException.
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public double getDouble(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
@@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext {
}
/**
- * <p>Sets a value to this, or throws an UnsupportedOperationException if
- * this is not supported. This default implementation does the latter.</p> *
+ * Sets a value to this, or throws an UnsupportedOperationException if
+ * this is not supported. This default implementation does the latter.
*
- * @param name The name of the variable to set.
+ * @param name the name of the variable to set.
* @param value the value to set. Ownership of this value is transferred to this - if it is mutable
* (not frozen) it may be modified during execution
- * @since 5.1.5
*/
public void put(String name, Value value) {
throw new UnsupportedOperationException(this + " does not support variable assignment");
}
/**
- * <p>Returns all the names available in this, or throws an
+ * Returns all the names available in this, or throws an
* UnsupportedOperationException if this operation is not supported. This
- * default implementation does the latter.</p>
+ * default implementation does the latter.
*
- * @return The set of all variable names.
+ * @return the set of all variable names.
*/
public Set<String> names() {
throw new UnsupportedOperationException(this + " does not support return a list of its names");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index ea750295423..2ef4a2ede2f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
@@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value {
public boolean hasDouble() { return true; }
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index ac8aba6a617..dad69b31181 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A string value.
*
* @author bratseth
- * @since 5.1.21
*/
public class StringValue extends Value {
@@ -35,6 +37,11 @@ public class StringValue extends Value {
}
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 49c3ccb7b01..26c30fe5ed2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -2,14 +2,10 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.Optional;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
/**
* A Value containing a tensor.
@@ -23,7 +19,7 @@ public class TensorValue extends Value {
/** The tensor value of this */
private final Tensor value;
-
+
public TensorValue(Tensor value) {
this.value = value;
}
@@ -131,7 +127,7 @@ public class TensorValue extends Value {
public Value compare(TruthOperator operator, Value argument) {
return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
}
-
+
private Tensor compareTensor(TruthOperator operator, Tensor argument) {
switch (operator) {
case LARGER: return value.larger(argument);
@@ -152,7 +148,7 @@ public class TensorValue extends Value {
else
return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
}
-
+
private Tensor functionOnTensor(Function function, Tensor argument) {
switch (function) {
case min: return value.min(argument);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index b2ccbe572d0..40d70e0022c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* The result of a ranking expression evaluation.
@@ -25,6 +27,14 @@ public abstract class Value {
return new DoubleValue(asDouble());
}
+ /** Returns this as a tensor value */
+ public abstract Tensor asTensor();
+
+ /** A utility method for wrapping a double in a rank 0 tensor */
+ protected Tensor doubleAsTensor(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
+ }
+
/** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */
public abstract boolean hasDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
new file mode 100644
index 00000000000..b4a9b363ade
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -0,0 +1,51 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * The result of importing a TensorFlow model into Vespa:
+ * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
+ * - A list of named constant tensors
+ * - A list of expected input tensors, with their tensor type
+ * - A list of warning messages
+ *
+ * @author bratseth
+ */
+// This object can be built incrementally within this package, but is immutable when observed from outside the package
+// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
+public class ImportResult {
+
+ private final List<RankingExpression> expressions = new ArrayList<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final List<String> warnings = new ArrayList<>();
+
+ void add(RankingExpression expression) { expressions.add(expression); }
+ void set(String name, Tensor constant) { constants.put(name, constant); }
+ void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void warn(String warning) { warnings.add(warning); }
+
+ /** Returns an immutable list of the expressions of this */
+ public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+
+ /** Returns an immutable map of the constants of this */
+ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+
+ /** Returns an immutable map of the arguments of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+ /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
+ public List<String> warnings() {
+ return warnings.stream().sorted().collect(Collectors.toList());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java
deleted file mode 100644
index 235771bfa9c..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.Tensor;
-
-/**
- * A tensor with a name
- *
- * @author bratseth
- */
-public class NamedTensor {
-
- private final String name;
- private final Tensor tensor;
-
- public NamedTensor(String name, Tensor tensor) {
- this.name = name;
- this.tensor = tensor;
- }
-
- public String name() { return name; }
- public Tensor tensor() { return tensor; }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 183cfabbd87..e7f7b5ef2f4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -10,13 +10,14 @@ import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
+import com.yahoo.tensor.functions.TensorFunction;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
+import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
@@ -45,16 +46,35 @@ class OperationMapper {
private TensorConverter tensorConverter = new TensorConverter();
TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
- // Note that this generalizes the corresponding TF function as it does not verify that the tensor
- // types are the same, with the assumption that this already happened on the TF side
- // (and if not, this should do the right thing anyway)
ensureArguments(2, arguments, "join");
TypedTensorFunction a = arguments.get(0);
TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < b.type().rank())
+ throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " +
+ "but this is not supported when the second argument has a higher rank");
+
+ TensorFunction bFunction = b.function();
+
+ if (a.type().rank() > b.type().rank()) {
+ // Well now we have entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensor do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+ List<String> renameFrom = new ArrayList<>();
+ List<String> renameTo = new ArrayList<>();
+ int sizeDifference = a.type().rank() - b.type().rank();
+ for (int i = 0; i < b.type().rank(); i++) {
+ renameFrom.add(b.type().dimensions().get(i).name());
+ renameTo.add("d" + (sizeDifference + i));
+ }
+ bFunction = new Rename(bFunction, renameFrom, renameTo);
+ }
- TensorType resultType = Join.outputType(a.type(), b.type());
- Join function = new Join(a.function(), b.function(), doubleFunction);
- return new TypedTensorFunction(resultType, function);
+ Join function = new Join(a.function(), bFunction, doubleFunction);
+ return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank
}
TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
@@ -66,35 +86,37 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs, SavedModelBundle model,
- List<NamedTensor> constants) {
- String name;
- TensorType type;
- if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model
- if (tfNode.getInputList().size() != 1)
- throw new IllegalArgumentException("A Variable/read node must have one input but has " +
- tfNode.getInputList().size());
- name = tfNode.getInput(0);
- AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
- Session.Runner fetched = model.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> result = fetched.run();
- if ( result.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size());
- Tensor constant = tensorConverter.toVespaTensor(result.get(0));
- constants.add(new NamedTensor(name, constant));
- return new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
- }
- else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
- name = tfNode.getName();
- type = inputs.get(name);
- if (type == null)
- throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
- "', but there is no such input");
- return new TypedTensorFunction(type, new VariableTensor(name));
- }
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ String name = tfNode.getName();
+ TensorType type = result.arguments().get(name);
+ if (type == null)
+ throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name +
+ "', but there is no such input");
+ // Included literally in the expression and so must be produced by a separate macro in the rank profile
+ return new TypedTensorFunction(type, new VariableTensor(name));
+ }
+
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ if ( ! tfNode.getName().endsWith("/read"))
+ throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
+ "nodes are only supported when reading variables");
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+
+ String name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+ Session.Runner fetched = model.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if ( importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
+ importedTensors.size());
+ Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
+ result.set(name, constant);
+ return new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
@@ -106,21 +128,18 @@ class OperationMapper {
if (a.type().rank() != b.type().rank())
throw new IllegalArgumentException("Tensors in matmul must have the same rank");
- // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
- // and the last dimension of the second argument be not present in the first argument, while leaving the
+ String afterLastDim = "d" + (a.type().rank() + 1);
+ // Let the first dimension of the second tensor be the same as the second dimension of the first
+ // and the second dimension of the second argument be not present in the first argument, while leaving the
// rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
- // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
-
- String beforeLastDim = "d" + (a.type().rank() - 1);
- String lastDim = "d" + a.type().rank();
- String afterLastDim = "d" + (a.type().rank() + 1);
+ // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
- Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
- ImmutableList.of(lastDim, afterLastDim));
- Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
- return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
- new Rename(matmul, afterLastDim, lastDim));
+ Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
+ ImmutableList.of("d1", afterLastDim));
+ Matmul matmul = new Matmul(a.function(), renamedB, "d1");
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
+ new Rename(matmul, afterLastDim, "d1"));
}
TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
index a74445008b7..df43225c333 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -24,8 +24,10 @@ public class TensorConverter {
private TensorType toVespaTensorType(long[] shape) {
TensorType.Builder b = new TensorType.Builder();
int dimensionIndex = 0;
- for (long dimensionSize : shape)
- b.indexed("d" + (dimensionIndex++), (int)dimensionSize);
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed("d" + (dimensionIndex++), (int) dimensionSize);
+ }
return b.build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 51f1e444e70..33523244129 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -16,11 +16,8 @@ import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.logging.Level;
import java.util.stream.Collectors;
/**
@@ -35,104 +32,100 @@ public class TensorFlowImporter {
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a pbtxt file.
- * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending).
+ * The name of the model is taken as the pb/pbtxt file name (not including the file ending).
*
* @param modelDir the directory containing the TensorFlow model files to import
- * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions
- * @param logger a receiver of any messages generated by the import process
- * @return the ranking expressions resulting from importing this TenorFlow model
*/
- public List<RankingExpression> importModel(String modelDir, List<NamedTensor> constants, MessageLogger logger) {
- try {
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger);
+ public ImportResult importModel(String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e);
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
}
+ }
+ public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
+ SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
+ ImportResult result = new ImportResult();
+ importInputs(signature.getInputsMap(), result);
+ result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
+ return result;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ }
}
- private List<RankingExpression> importGraph(MetaGraphDef graph, SavedModelBundle model,
- List<NamedTensor> constants, MessageLogger logger) {
- List<RankingExpression> expressions = new ArrayList<>();
+ private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ ImportResult result = new ImportResult();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap());
+ importInputs(signatureEntry.getValue().getInputsMap(), result);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
try {
- ExpressionNode result = importOutput(output.getValue(),
- inputs,
- graph.getGraphDef(),
- model,
- constants);
- expressions.add(new RankingExpression(output.getKey(), result));
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
+ result.add(new RankingExpression(output.getKey(), node));
}
catch (IllegalArgumentException e) {
- logger.log(Level.INFO, "Skipping output '" + output.getValue().getName() + "' of signature '" +
- signatureEntry.getValue().getMethodName() +
- "': " + Exceptions.toMessageString(e));
+ result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
+ signatureEntry.getValue().getMethodName() +
+ "': " + Exceptions.toMessageString(e));
}
}
}
- return expressions;
+ return result;
}
- private Map<String, TensorType> importInputs(Map<String, TensorInfo> inputInfoMap) {
- Map<String, TensorType> inputs = new HashMap<>();
- inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()),
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
+ inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
importTensorType(value.getTensorShape())));
- return inputs;
}
private TensorType importTensorType(TensorShapeProto tensorShape) {
TensorType.Builder b = new TensorType.Builder();
- for (int i = 0; i < tensorShape.getDimCount(); i++) {
- int dimensionSize = (int) tensorShape.getDim(i).getSize();
+ for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
+ int dimensionSize = (int)dimension.getSize();
if (dimensionSize >= 0)
- b.indexed("d" + i, dimensionSize);
+ b.indexed("d" + b.rank(), dimensionSize);
else
- b.indexed("d" + i); // unbound size
+ b.indexed("d" + b.rank()); // unbound size
}
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph,
- SavedModelBundle model, List<NamedTensor> constants) {
- NodeDef node = getNode(nameOf(output.getName()), graph);
- TensorFunction function = importNode(node, inputs, graph, model, constants).function();
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return importNode(nameOf(output.getName()), graph, model, result);
+ }
+
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
return new TensorFunctionNode(function); // wrap top level (only) as an expression
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph,
- SavedModelBundle model, List<NamedTensor> constants) {
- return tensorFunctionOf(tfNode, inputs, graph, model, constants);
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tensorFunctionOf(tfNode, graph, model, result);
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode,
- Map<String, TensorType> inputs,
- GraphDef graph,
- SavedModelBundle model,
- List<NamedTensor> constants) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        // Import arguments lazily below, as some nodes have unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos());
- case "identity" : return operationMapper.identity(tfNode, inputs, model, constants);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, result);
+ case "identity" : return operationMapper.identity(tfNode, model, result);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode,
- Map<String, TensorType> inputs,
- GraphDef graph,
- SavedModelBundle model,
- List<NamedTensor> constants) {
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
.collect(Collectors.toList());
}
@@ -151,11 +144,4 @@ public class TensorFlowImporter {
return name.split(":")[0];
}
- /** An interface which can be implemented to receive messages emitted during import */
- public interface MessageLogger {
-
- void log(Level level, String message);
-
- }
-
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
index 234d620d02f..5712da77700 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -3,9 +3,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-/**
+/**
* A tensor function returning a specific tensor type
- *
+ *
* @author bratseth
*/
final class TypedTensorFunction {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 71699b379b2..d366c9bfbe5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -14,23 +14,23 @@ import java.util.function.*;
/**
* A tensor generating function, whose arguments are determined by a tensor type
- *
+ *
* @author bratseth
*/
public class GeneratorLambdaFunctionNode extends CompositeNode {
private final TensorType type;
private final ExpressionNode generator;
-
+
public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
- throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
+ throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
"dimensions, but tried to generate " + type);
// TODO: Verify that the function only accesses the given arguments
this.type = type;
this.generator = generator;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(generator);
@@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
public Value evaluate(Context context) {
return generator.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as an operator which converts a list of integers into a double
*/
public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
@@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
context.put(type.dimensions().get(i).name(), arguments.get(i));
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return GeneratorLambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index d1f4cbddf6e..8af3448ca6f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -36,10 +36,17 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public List<ExpressionNode> children() {
return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .map(this::toExpressionNode)
.collect(Collectors.toList());
}
+ private ExpressionNode toExpressionNode(TensorFunction f) {
+ if (f instanceof TensorFunctionExpressionNode)
+ return ((TensorFunctionExpressionNode)f).expression;
+ else
+ return new TensorFunctionNode(f);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
new file mode 100644
index 00000000000..a1861a1c981
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A very simple MNIST classifier.
+
+See extensive documentation at
+https://www.tensorflow.org/get_started/mnist/beginners
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+import tensorflow as tf
+
+FLAGS = None
+
+
+def main(_):
+ # Import data
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+
+ # Create the model
+ x = tf.placeholder(tf.float32, [None, 784])
+ W = tf.Variable(tf.zeros([784, 10]))
+ b = tf.Variable(tf.zeros([10]))
+ y = tf.matmul(x, W) + b
+
+ # Define loss and optimizer
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ # The raw formulation of cross-entropy,
+ #
+ # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
+ # reduction_indices=[1]))
+ #
+ # can be numerically unstable.
+ #
+ # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+ # outputs of 'y', and then average across the batch.
+ cross_entropy = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+ # Train
+ for _ in range(1000):
+ batch_xs, batch_ys = mnist.train.next_batch(100)
+ sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
+
+ # Test trained model
+ correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print(sess.run(accuracy, feed_dict={x: mnist.test.images,
+ y_: mnist.test.labels}))
+
+ # Save the model
+ export_path = "saved"
+ print('Exporting trained model to ', export_path)
+ builder = tf.saved_model.builder.SavedModelBuilder(export_path)
+ signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y})
+ builder.add_meta_graph_and_variables(sess,
+ [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={'serving_default':signature})
+ builder.save(as_text=True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
+ help='Directory for storing input data')
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
index e01688669a1..8100dfd594d 100644
--- a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
@@ -237,6 +237,45 @@ meta_graphs {
}
}
op {
+ name: "ConcatV2"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
name: "Const"
output_arg {
name: "output"
@@ -291,14 +330,14 @@ meta_graphs {
is_commutative: true
}
op {
- name: "Fill"
+ name: "ExpandDims"
input_arg {
- name: "dims"
- type: DT_INT32
+ name: "input"
+ type_attr: "T"
}
input_arg {
- name: "value"
- type_attr: "T"
+ name: "dim"
+ type_attr: "Tdim"
}
output_arg {
name: "output"
@@ -308,48 +347,28 @@ meta_graphs {
name: "T"
type: "type"
}
- }
- op {
- name: "HashTableV2"
- output_arg {
- name: "table_handle"
- type: DT_RESOURCE
- }
- attr {
- name: "container"
- type: "string"
- default_value {
- s: ""
- }
- }
attr {
- name: "shared_name"
- type: "string"
+ name: "Tdim"
+ type: "type"
default_value {
- s: ""
+ type: DT_INT32
}
- }
- attr {
- name: "use_node_name_sharing"
- type: "bool"
- default_value {
- b: false
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
}
}
- attr {
- name: "key_dtype"
- type: "type"
- }
- attr {
- name: "value_dtype"
- type: "type"
- }
- is_stateful: true
}
op {
- name: "Identity"
+ name: "Fill"
input_arg {
- name: "input"
+ name: "dims"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
type_attr: "T"
}
output_arg {
@@ -362,37 +381,17 @@ meta_graphs {
}
}
op {
- name: "InitializeTableV2"
- input_arg {
- name: "table_handle"
- type: DT_RESOURCE
- }
- input_arg {
- name: "keys"
- type_attr: "Tkey"
- }
+ name: "FloorDiv"
input_arg {
- name: "values"
- type_attr: "Tval"
- }
- attr {
- name: "Tkey"
- type: "type"
- }
- attr {
- name: "Tval"
- type: "type"
+ name: "x"
+ type_attr: "T"
}
- is_stateful: true
- }
- op {
- name: "Log"
input_arg {
- name: "x"
+ name: "y"
type_attr: "T"
}
output_arg {
- name: "y"
+ name: "z"
type_attr: "T"
}
attr {
@@ -403,6 +402,12 @@ meta_graphs {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -410,32 +415,19 @@ meta_graphs {
}
}
op {
- name: "LookupTableFindV2"
- input_arg {
- name: "table_handle"
- type: DT_RESOURCE
- }
- input_arg {
- name: "keys"
- type_attr: "Tin"
- }
+ name: "Identity"
input_arg {
- name: "default_value"
- type_attr: "Tout"
+ name: "input"
+ type_attr: "T"
}
output_arg {
- name: "values"
- type_attr: "Tout"
- }
- attr {
- name: "Tin"
- type: "type"
+ name: "output"
+ type_attr: "T"
}
attr {
- name: "Tout"
+ name: "T"
type: "type"
}
- is_stateful: true
}
op {
name: "MatMul"
@@ -481,6 +473,35 @@ meta_graphs {
}
}
op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
name: "Mean"
input_arg {
name: "input"
@@ -592,32 +613,6 @@ meta_graphs {
is_commutative: true
}
op {
- name: "Neg"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- output_arg {
- name: "y"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
- }
- op {
name: "NoOp"
}
op {
@@ -650,88 +645,6 @@ meta_graphs {
}
}
op {
- name: "ParseExample"
- input_arg {
- name: "serialized"
- type: DT_STRING
- }
- input_arg {
- name: "names"
- type: DT_STRING
- }
- input_arg {
- name: "sparse_keys"
- type: DT_STRING
- number_attr: "Nsparse"
- }
- input_arg {
- name: "dense_keys"
- type: DT_STRING
- number_attr: "Ndense"
- }
- input_arg {
- name: "dense_defaults"
- type_list_attr: "Tdense"
- }
- output_arg {
- name: "sparse_indices"
- type: DT_INT64
- number_attr: "Nsparse"
- }
- output_arg {
- name: "sparse_values"
- type_list_attr: "sparse_types"
- }
- output_arg {
- name: "sparse_shapes"
- type: DT_INT64
- number_attr: "Nsparse"
- }
- output_arg {
- name: "dense_values"
- type_list_attr: "Tdense"
- }
- attr {
- name: "Nsparse"
- type: "int"
- has_minimum: true
- }
- attr {
- name: "Ndense"
- type: "int"
- has_minimum: true
- }
- attr {
- name: "sparse_types"
- type: "list(type)"
- has_minimum: true
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_INT64
- type: DT_STRING
- }
- }
- }
- attr {
- name: "Tdense"
- type: "list(type)"
- has_minimum: true
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_INT64
- type: DT_STRING
- }
- }
- }
- attr {
- name: "dense_shapes"
- type: "list(shape)"
- has_minimum: true
- }
- }
- op {
name: "Placeholder"
output_arg {
name: "output"
@@ -752,22 +665,47 @@ meta_graphs {
}
}
op {
- name: "Range"
- input_arg {
- name: "start"
- type_attr: "Tidx"
- }
+ name: "Prod"
input_arg {
- name: "limit"
- type_attr: "Tidx"
+ name: "input"
+ type_attr: "T"
}
input_arg {
- name: "delta"
+ name: "reduction_indices"
type_attr: "Tidx"
}
output_arg {
name: "output"
- type_attr: "Tidx"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
}
attr {
name: "Tidx"
@@ -777,8 +715,6 @@ meta_graphs {
}
allowed_values {
list {
- type: DT_FLOAT
- type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
}
@@ -786,15 +722,19 @@ meta_graphs {
}
}
op {
- name: "Reciprocal"
+ name: "RealDiv"
input_arg {
name: "x"
type_attr: "T"
}
- output_arg {
+ input_arg {
name: "y"
type_attr: "T"
}
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
attr {
name: "T"
type: "type"
@@ -803,6 +743,10 @@ meta_graphs {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
@@ -943,13 +887,54 @@ meta_graphs {
}
}
op {
- name: "Softmax"
+ name: "Slice"
input_arg {
- name: "logits"
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "begin"
+ type_attr: "Index"
+ }
+ input_arg {
+ name: "size"
+ type_attr: "Index"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Index"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "SoftmaxCrossEntropyWithLogits"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "labels"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "loss"
type_attr: "T"
}
output_arg {
- name: "softmax"
+ name: "backprop"
type_attr: "T"
}
attr {
@@ -1011,6 +996,10 @@ meta_graphs {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
@@ -1109,49 +1098,6 @@ meta_graphs {
}
}
op {
- name: "TopKV2"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- input_arg {
- name: "k"
- type: DT_INT32
- }
- output_arg {
- name: "values"
- type_attr: "T"
- }
- output_arg {
- name: "indices"
- type: DT_INT32
- }
- attr {
- name: "sorted"
- type: "bool"
- default_value {
- b: true
- }
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
- type: DT_HALF
- }
- }
- }
- }
- op {
name: "VariableV2"
output_arg {
name: "ref"
@@ -1182,224 +1128,27 @@ meta_graphs {
}
is_stateful: true
}
- }
- tags: "serve"
- tensorflow_version: "1.3.0"
- tensorflow_git_version: "v1.3.0-rc2-20-g0787eee"
- }
- graph_def {
- node {
- name: "tf_example"
- op: "Placeholder"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- unknown_rank: true
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- unknown_rank: true
- }
- }
- }
- }
- node {
- name: "ParseExample/Const"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_FLOAT
- tensor_shape {
- dim {
- }
- }
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample/names"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- }
- }
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample/dense_keys_0"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "x"
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample"
- op: "ParseExample"
- input: "tf_example"
- input: "ParseExample/ParseExample/names"
- input: "ParseExample/ParseExample/dense_keys_0"
- input: "ParseExample/Const"
- attr {
- key: "Ndense"
- value {
- i: 1
- }
- }
- attr {
- key: "Nsparse"
- value {
- i: 0
- }
- }
- attr {
- key: "Tdense"
- value {
- list {
- type: DT_FLOAT
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 784
- }
- }
- }
- }
- }
- attr {
- key: "dense_shapes"
- value {
- list {
- shape {
- dim {
- size: 784
- }
- }
- }
- }
- }
- attr {
- key: "sparse_types"
- value {
- list {
- }
+ op {
+ name: "ZerosLike"
+ input_arg {
+ name: "x"
+ type_attr: "T"
}
- }
- }
- node {
- name: "x"
- op: "Identity"
- input: "ParseExample/ParseExample"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
+ output_arg {
+ name: "y"
+ type_attr: "T"
}
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 784
- }
- }
- }
+ attr {
+ name: "T"
+ type: "type"
}
}
}
+ tags: "serve"
+ tensorflow_version: "1.4.1"
+ tensorflow_git_version: "v1.4.0-19-ga52c8d9b01"
+ }
+ graph_def {
node {
name: "Placeholder"
op: "Placeholder"
@@ -1412,7 +1161,7 @@ meta_graphs {
size: -1
}
dim {
- size: 10
+ size: 784
}
}
}
@@ -1432,7 +1181,7 @@ meta_graphs {
size: -1
}
dim {
- size: 10
+ size: 784
}
}
}
@@ -1767,15 +1516,9 @@ meta_graphs {
}
}
node {
- name: "init"
- op: "NoOp"
- input: "^Variable/Assign"
- input: "^Variable_1/Assign"
- }
- node {
name: "MatMul"
op: "MatMul"
- input: "x"
+ input: "Placeholder"
input: "Variable/read"
attr {
key: "T"
@@ -1839,15 +1582,8 @@ meta_graphs {
}
}
node {
- name: "y"
- op: "Softmax"
- input: "add"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
+ name: "Placeholder_1"
+ op: "Placeholder"
attr {
key: "_output_shapes"
value {
@@ -1863,38 +1599,60 @@ meta_graphs {
}
}
}
- }
- node {
- name: "Log"
- op: "Log"
- input: "y"
attr {
- key: "T"
+ key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank"
+ op: "Const"
+ attr {
key: "_output_shapes"
value {
list {
shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
}
}
}
}
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
}
node {
- name: "mul"
- op: "Mul"
- input: "Placeholder"
- input: "Log"
+ name: "Shape"
+ op: "Shape"
+ input: "add"
attr {
key: "T"
value {
@@ -1907,27 +1665,27 @@ meta_graphs {
list {
shape {
dim {
- size: -1
- }
- dim {
- size: 10
+ size: 2
}
}
}
}
}
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
}
node {
- name: "Const"
+ name: "Rank_1"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
- dim {
- size: 2
- }
}
}
}
@@ -1944,20 +1702,16 @@ meta_graphs {
tensor {
dtype: DT_INT32
tensor_shape {
- dim {
- size: 2
- }
}
- tensor_content: "\000\000\000\000\001\000\000\000"
+ int_val: 2
}
}
}
}
node {
- name: "Sum"
- op: "Sum"
- input: "mul"
- input: "Const"
+ name: "Shape_1"
+ op: "Shape"
+ input: "add"
attr {
key: "T"
value {
@@ -1965,11 +1719,27 @@ meta_graphs {
}
}
attr {
- key: "Tidx"
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
value {
type: DT_INT32
}
}
+ }
+ node {
+ name: "Sub/y"
+ op: "Const"
attr {
key: "_output_shapes"
value {
@@ -1980,20 +1750,32 @@ meta_graphs {
}
}
attr {
- key: "keep_dims"
+ key: "dtype"
value {
- b: false
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
}
}
}
node {
- name: "Neg"
- op: "Neg"
- input: "Sum"
+ name: "Sub"
+ op: "Sub"
+ input: "Rank_1"
+ input: "Sub/y"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2007,46 +1789,51 @@ meta_graphs {
}
}
node {
- name: "gradients/Shape"
- op: "Const"
+ name: "Slice/begin"
+ op: "Pack"
+ input: "Sub"
attr {
- key: "_output_shapes"
+ key: "N"
value {
- list {
- shape {
- dim {
- }
- }
- }
+ i: 1
}
}
attr {
- key: "dtype"
+ key: "T"
value {
type: DT_INT32
}
}
attr {
- key: "value"
+ key: "_output_shapes"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
+ list {
+ shape {
dim {
+ size: 1
}
}
}
}
}
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
}
node {
- name: "gradients/Const"
+ name: "Slice/size"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
+ dim {
+ size: 1
+ }
}
}
}
@@ -2054,50 +1841,40 @@ meta_graphs {
attr {
key: "dtype"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
- dtype: DT_FLOAT
+ dtype: DT_INT32
tensor_shape {
+ dim {
+ size: 1
+ }
}
- float_val: 1.0
+ int_val: 1
}
}
}
}
node {
- name: "gradients/Fill"
- op: "Fill"
- input: "gradients/Shape"
- input: "gradients/Const"
+ name: "Slice"
+ op: "Slice"
+ input: "Shape_1"
+ input: "Slice/begin"
+ input: "Slice/size"
attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_output_shapes"
+ key: "Index"
value {
- list {
- shape {
- }
- }
+ type: DT_INT32
}
}
- }
- node {
- name: "gradients/Neg_grad/Neg"
- op: "Neg"
- input: "gradients/Fill"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2105,13 +1882,16 @@ meta_graphs {
value {
list {
shape {
+ dim {
+ size: 1
+ }
}
}
}
}
}
node {
- name: "gradients/Sum_grad/Reshape/shape"
+ name: "concat/values_0"
op: "Const"
attr {
key: "_output_shapes"
@@ -2119,7 +1899,7 @@ meta_graphs {
list {
shape {
dim {
- size: 2
+ size: 1
}
}
}
@@ -2138,55 +1918,66 @@ meta_graphs {
dtype: DT_INT32
tensor_shape {
dim {
- size: 2
+ size: 1
}
}
- tensor_content: "\001\000\000\000\001\000\000\000"
+ int_val: -1
}
}
}
}
node {
- name: "gradients/Sum_grad/Reshape"
- op: "Reshape"
- input: "gradients/Neg_grad/Neg"
- input: "gradients/Sum_grad/Reshape/shape"
+ name: "concat/axis"
+ op: "Const"
attr {
- key: "T"
+ key: "_output_shapes"
value {
- type: DT_FLOAT
+ list {
+ shape {
+ }
+ }
}
}
attr {
- key: "Tshape"
+ key: "dtype"
value {
type: DT_INT32
}
}
attr {
- key: "_output_shapes"
+ key: "value"
value {
- list {
- shape {
- dim {
- size: 1
- }
- dim {
- size: 1
- }
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
}
+ int_val: 0
}
}
}
}
node {
- name: "gradients/Sum_grad/Shape"
- op: "Shape"
- input: "mul"
+ name: "concat"
+ op: "ConcatV2"
+ input: "concat/values_0"
+ input: "Slice"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
}
}
attr {
@@ -2201,18 +1992,12 @@ meta_graphs {
}
}
}
- attr {
- key: "out_type"
- value {
- type: DT_INT32
- }
- }
}
node {
- name: "gradients/Sum_grad/Tile"
- op: "Tile"
- input: "gradients/Sum_grad/Reshape"
- input: "gradients/Sum_grad/Shape"
+ name: "Reshape"
+ op: "Reshape"
+ input: "add"
+ input: "concat"
attr {
key: "T"
value {
@@ -2220,7 +2005,7 @@ meta_graphs {
}
}
attr {
- key: "Tmultiples"
+ key: "Tshape"
value {
type: DT_INT32
}
@@ -2234,7 +2019,7 @@ meta_graphs {
size: -1
}
dim {
- size: 10
+ size: -1
}
}
}
@@ -2242,38 +2027,39 @@ meta_graphs {
}
}
node {
- name: "gradients/mul_grad/Shape"
- op: "Shape"
- input: "Placeholder"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
+ name: "Rank_2"
+ op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
- dim {
- size: 2
- }
}
}
}
}
attr {
- key: "out_type"
+ key: "dtype"
value {
type: DT_INT32
}
}
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
}
node {
- name: "gradients/mul_grad/Shape_1"
+ name: "Shape_2"
op: "Shape"
- input: "Log"
+ input: "Placeholder_1"
attr {
key: "T"
value {
@@ -2300,43 +2086,44 @@ meta_graphs {
}
}
node {
- name: "gradients/mul_grad/BroadcastGradientArgs"
- op: "BroadcastGradientArgs"
- input: "gradients/mul_grad/Shape"
- input: "gradients/mul_grad/Shape_1"
+ name: "Sub_1/y"
+ op: "Const"
attr {
- key: "T"
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
value {
type: DT_INT32
}
}
attr {
- key: "_output_shapes"
+ key: "value"
value {
- list {
- shape {
- dim {
- size: -1
- }
- }
- shape {
- dim {
- size: -1
- }
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
}
+ int_val: 1
}
}
}
}
node {
- name: "gradients/mul_grad/mul"
- op: "Mul"
- input: "gradients/Sum_grad/Tile"
- input: "Log"
+ name: "Sub_1"
+ op: "Sub"
+ input: "Rank_2"
+ input: "Sub_1/y"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2344,30 +2131,23 @@ meta_graphs {
value {
list {
shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/mul_grad/Sum"
- op: "Sum"
- input: "gradients/mul_grad/mul"
- input: "gradients/mul_grad/BroadcastGradientArgs"
+ name: "Slice_1/begin"
+ op: "Pack"
+ input: "Sub_1"
attr {
- key: "T"
+ key: "N"
value {
- type: DT_FLOAT
+ i: 1
}
}
attr {
- key: "Tidx"
+ key: "T"
value {
type: DT_INT32
}
@@ -2377,60 +2157,72 @@ meta_graphs {
value {
list {
shape {
- unknown_rank: true
+ dim {
+ size: 1
+ }
}
}
}
}
attr {
- key: "keep_dims"
+ key: "axis"
value {
- b: false
+ i: 0
}
}
}
node {
- name: "gradients/mul_grad/Reshape"
- op: "Reshape"
- input: "gradients/mul_grad/Sum"
- input: "gradients/mul_grad/Shape"
+ name: "Slice_1/size"
+ op: "Const"
attr {
- key: "T"
+ key: "_output_shapes"
value {
- type: DT_FLOAT
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
}
}
attr {
- key: "Tshape"
+ key: "dtype"
value {
type: DT_INT32
}
}
attr {
- key: "_output_shapes"
+ key: "value"
value {
- list {
- shape {
- dim {
- size: -1
- }
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
dim {
- size: 10
+ size: 1
}
}
+ int_val: 1
}
}
}
}
node {
- name: "gradients/mul_grad/mul_1"
- op: "Mul"
- input: "Placeholder"
- input: "gradients/Sum_grad/Tile"
+ name: "Slice_1"
+ op: "Slice"
+ input: "Shape_2"
+ input: "Slice_1/begin"
+ input: "Slice_1/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2439,10 +2231,7 @@ meta_graphs {
list {
shape {
dim {
- size: -1
- }
- dim {
- size: 10
+ size: 1
}
}
}
@@ -2450,44 +2239,113 @@ meta_graphs {
}
}
node {
- name: "gradients/mul_grad/Sum_1"
- op: "Sum"
- input: "gradients/mul_grad/mul_1"
- input: "gradients/mul_grad/BroadcastGradientArgs:1"
+ name: "concat_1/values_0"
+ op: "Const"
attr {
- key: "T"
+ key: "_output_shapes"
value {
- type: DT_FLOAT
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
}
}
attr {
- key: "Tidx"
+ key: "dtype"
value {
type: DT_INT32
}
}
attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
key: "_output_shapes"
value {
list {
shape {
- unknown_rank: true
}
}
}
}
attr {
- key: "keep_dims"
+ key: "dtype"
value {
- b: false
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "concat_1/values_0"
+ input: "Slice_1"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
}
}
}
node {
- name: "gradients/mul_grad/Reshape_1"
+ name: "Reshape_1"
op: "Reshape"
- input: "gradients/mul_grad/Sum_1"
- input: "gradients/mul_grad/Shape_1"
+ input: "Placeholder_1"
+ input: "concat_1"
attr {
key: "T"
value {
@@ -2509,7 +2367,7 @@ meta_graphs {
size: -1
}
dim {
- size: 10
+ size: -1
}
}
}
@@ -2517,16 +2375,10 @@ meta_graphs {
}
}
node {
- name: "gradients/mul_grad/tuple/group_deps"
- op: "NoOp"
- input: "^gradients/mul_grad/Reshape"
- input: "^gradients/mul_grad/Reshape_1"
- }
- node {
- name: "gradients/mul_grad/tuple/control_dependency"
- op: "Identity"
- input: "gradients/mul_grad/Reshape"
- input: "^gradients/mul_grad/tuple/group_deps"
+ name: "SoftmaxCrossEntropyWithLogits"
+ op: "SoftmaxCrossEntropyWithLogits"
+ input: "Reshape"
+ input: "Reshape_1"
attr {
key: "T"
value {
@@ -2534,73 +2386,127 @@ meta_graphs {
}
}
attr {
- key: "_class"
+ key: "_output_shapes"
value {
list {
- s: "loc:@gradients/mul_grad/Reshape"
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
}
}
}
+ }
+ node {
+ name: "Sub_2/y"
+ op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
}
}
}
}
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
}
node {
- name: "gradients/mul_grad/tuple/control_dependency_1"
- op: "Identity"
- input: "gradients/mul_grad/Reshape_1"
- input: "^gradients/mul_grad/tuple/group_deps"
+ name: "Sub_2"
+ op: "Sub"
+ input: "Rank"
+ input: "Sub_2/y"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
- key: "_class"
+ key: "_output_shapes"
value {
list {
- s: "loc:@gradients/mul_grad/Reshape_1"
+ shape {
+ }
}
}
}
+ }
+ node {
+ name: "Slice_2/begin"
+ op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
- size: -1
+ size: 1
}
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
dim {
- size: 10
+ size: 1
}
}
+ int_val: 0
}
}
}
}
node {
- name: "gradients/Log_grad/Reciprocal"
- op: "Reciprocal"
- input: "y"
- input: "^gradients/mul_grad/tuple/control_dependency_1"
+ name: "Slice_2/size"
+ op: "Pack"
+ input: "Sub_2"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2609,25 +2515,35 @@ meta_graphs {
list {
shape {
dim {
- size: -1
- }
- dim {
- size: 10
+ size: 1
}
}
}
}
}
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
}
node {
- name: "gradients/Log_grad/mul"
- op: "Mul"
- input: "gradients/mul_grad/tuple/control_dependency_1"
- input: "gradients/Log_grad/Reciprocal"
+ name: "Slice_2"
+ op: "Slice"
+ input: "Shape"
+ input: "Slice_2/begin"
+ input: "Slice_2/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2638,19 +2554,16 @@ meta_graphs {
dim {
size: -1
}
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/y_grad/mul"
- op: "Mul"
- input: "gradients/Log_grad/mul"
- input: "y"
+ name: "Reshape_2"
+ op: "Reshape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ input: "Slice_2"
attr {
key: "T"
value {
@@ -2658,6 +2571,12 @@ meta_graphs {
}
}
attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
@@ -2665,16 +2584,13 @@ meta_graphs {
dim {
size: -1
}
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/y_grad/Sum/reduction_indices"
+ name: "Const"
op: "Const"
attr {
key: "_output_shapes"
@@ -2704,16 +2620,16 @@ meta_graphs {
size: 1
}
}
- int_val: 1
+ int_val: 0
}
}
}
}
node {
- name: "gradients/y_grad/Sum"
- op: "Sum"
- input: "gradients/y_grad/mul"
- input: "gradients/y_grad/Sum/reduction_indices"
+ name: "Mean"
+ op: "Mean"
+ input: "Reshape_2"
+ input: "Const"
attr {
key: "T"
value {
@@ -2731,9 +2647,6 @@ meta_graphs {
value {
list {
shape {
- dim {
- size: -1
- }
}
}
}
@@ -2746,7 +2659,7 @@ meta_graphs {
}
}
node {
- name: "gradients/y_grad/Reshape/shape"
+ name: "gradients/Shape"
op: "Const"
attr {
key: "_output_shapes"
@@ -2754,7 +2667,6 @@ meta_graphs {
list {
shape {
dim {
- size: 2
}
}
}
@@ -2773,19 +2685,47 @@ meta_graphs {
dtype: DT_INT32
tensor_shape {
dim {
- size: 2
}
}
- tensor_content: "\377\377\377\377\001\000\000\000"
}
}
}
}
node {
- name: "gradients/y_grad/Reshape"
- op: "Reshape"
- input: "gradients/y_grad/Sum"
- input: "gradients/y_grad/Reshape/shape"
+ name: "gradients/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Fill"
+ op: "Fill"
+ input: "gradients/Shape"
+ input: "gradients/Const"
attr {
key: "T"
value {
@@ -2793,32 +2733,56 @@ meta_graphs {
}
}
attr {
- key: "Tshape"
+ key: "_output_shapes"
value {
- type: DT_INT32
+ list {
+ shape {
+ }
+ }
}
}
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape/shape"
+ op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
- size: -1
+ size: 1
}
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
dim {
size: 1
}
}
+ int_val: 1
}
}
}
}
node {
- name: "gradients/y_grad/sub"
- op: "Sub"
- input: "gradients/Log_grad/mul"
- input: "gradients/y_grad/Reshape"
+ name: "gradients/Mean_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Fill"
+ input: "gradients/Mean_grad/Reshape/shape"
attr {
key: "T"
value {
@@ -2826,26 +2790,58 @@ meta_graphs {
}
}
attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
shape {
dim {
- size: -1
+ size: 1
}
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
dim {
- size: 10
+ size: 1
}
}
}
}
}
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
}
node {
- name: "gradients/y_grad/mul_1"
- op: "Mul"
- input: "gradients/y_grad/sub"
- input: "y"
+ name: "gradients/Mean_grad/Tile"
+ op: "Tile"
+ input: "gradients/Mean_grad/Reshape"
+ input: "gradients/Mean_grad/Shape"
attr {
key: "T"
value {
@@ -2853,6 +2849,12 @@ meta_graphs {
}
}
attr {
+ key: "Tmultiples"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
@@ -2860,18 +2862,15 @@ meta_graphs {
dim {
size: -1
}
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/add_grad/Shape"
+ name: "gradients/Mean_grad/Shape_1"
op: "Shape"
- input: "MatMul"
+ input: "Reshape_2"
attr {
key: "T"
value {
@@ -2884,7 +2883,7 @@ meta_graphs {
list {
shape {
dim {
- size: 2
+ size: 1
}
}
}
@@ -2898,7 +2897,7 @@ meta_graphs {
}
}
node {
- name: "gradients/add_grad/Shape_1"
+ name: "gradients/Mean_grad/Shape_2"
op: "Const"
attr {
key: "_output_shapes"
@@ -2906,7 +2905,6 @@ meta_graphs {
list {
shape {
dim {
- size: 1
}
}
}
@@ -2925,23 +2923,21 @@ meta_graphs {
dtype: DT_INT32
tensor_shape {
dim {
- size: 1
}
}
- int_val: 10
}
}
}
}
node {
- name: "gradients/add_grad/BroadcastGradientArgs"
- op: "BroadcastGradientArgs"
- input: "gradients/add_grad/Shape"
- input: "gradients/add_grad/Shape_1"
+ name: "gradients/Mean_grad/Const"
+ op: "Const"
attr {
- key: "T"
+ key: "_class"
value {
- type: DT_INT32
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
}
}
attr {
@@ -2950,27 +2946,42 @@ meta_graphs {
list {
shape {
dim {
- size: -1
+ size: 1
}
}
- shape {
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
dim {
- size: -1
+ size: 1
}
}
+ int_val: 0
}
}
}
}
node {
- name: "gradients/add_grad/Sum"
- op: "Sum"
- input: "gradients/y_grad/mul_1"
- input: "gradients/add_grad/BroadcastGradientArgs"
+ name: "gradients/Mean_grad/Prod"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_1"
+ input: "gradients/Mean_grad/Const"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -2980,11 +2991,18 @@ meta_graphs {
}
}
attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
shape {
- unknown_rank: true
}
}
}
@@ -2997,20 +3015,14 @@ meta_graphs {
}
}
node {
- name: "gradients/add_grad/Reshape"
- op: "Reshape"
- input: "gradients/add_grad/Sum"
- input: "gradients/add_grad/Shape"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
+ name: "gradients/Mean_grad/Const_1"
+ op: "Const"
attr {
- key: "Tshape"
+ key: "_class"
value {
- type: DT_INT32
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
}
}
attr {
@@ -3019,25 +3031,42 @@ meta_graphs {
list {
shape {
dim {
- size: -1
+ size: 1
}
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
dim {
- size: 10
+ size: 1
}
}
+ int_val: 0
}
}
}
}
node {
- name: "gradients/add_grad/Sum_1"
- op: "Sum"
- input: "gradients/y_grad/mul_1"
- input: "gradients/add_grad/BroadcastGradientArgs:1"
+ name: "gradients/Mean_grad/Prod_1"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_2"
+ input: "gradients/Mean_grad/Const_1"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
@@ -3047,11 +3076,18 @@ meta_graphs {
}
}
attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
shape {
- unknown_rank: true
}
}
}
@@ -3064,57 +3100,59 @@ meta_graphs {
}
}
node {
- name: "gradients/add_grad/Reshape_1"
- op: "Reshape"
- input: "gradients/add_grad/Sum_1"
- input: "gradients/add_grad/Shape_1"
+ name: "gradients/Mean_grad/Maximum/y"
+ op: "Const"
attr {
- key: "T"
+ key: "_class"
value {
- type: DT_FLOAT
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
}
}
attr {
- key: "Tshape"
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
value {
type: DT_INT32
}
}
attr {
- key: "_output_shapes"
+ key: "value"
value {
- list {
- shape {
- dim {
- size: 10
- }
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
}
+ int_val: 1
}
}
}
}
node {
- name: "gradients/add_grad/tuple/group_deps"
- op: "NoOp"
- input: "^gradients/add_grad/Reshape"
- input: "^gradients/add_grad/Reshape_1"
- }
- node {
- name: "gradients/add_grad/tuple/control_dependency"
- op: "Identity"
- input: "gradients/add_grad/Reshape"
- input: "^gradients/add_grad/tuple/group_deps"
+ name: "gradients/Mean_grad/Maximum"
+ op: "Maximum"
+ input: "gradients/Mean_grad/Prod_1"
+ input: "gradients/Mean_grad/Maximum/y"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
- s: "loc:@gradients/add_grad/Reshape"
+ s: "loc:@gradients/Mean_grad/Shape_1"
}
}
}
@@ -3123,33 +3161,27 @@ meta_graphs {
value {
list {
shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/add_grad/tuple/control_dependency_1"
- op: "Identity"
- input: "gradients/add_grad/Reshape_1"
- input: "^gradients/add_grad/tuple/group_deps"
+ name: "gradients/Mean_grad/floordiv"
+ op: "FloorDiv"
+ input: "gradients/Mean_grad/Prod"
+ input: "gradients/Mean_grad/Maximum"
attr {
key: "T"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
- s: "loc:@gradients/add_grad/Reshape_1"
+ s: "loc:@gradients/Mean_grad/Shape_1"
}
}
}
@@ -3158,58 +3190,65 @@ meta_graphs {
value {
list {
shape {
- dim {
- size: 10
- }
}
}
}
}
}
node {
- name: "gradients/MatMul_grad/MatMul"
- op: "MatMul"
- input: "gradients/add_grad/tuple/control_dependency"
- input: "Variable/read"
+ name: "gradients/Mean_grad/Cast"
+ op: "Cast"
+ input: "gradients/Mean_grad/floordiv"
attr {
- key: "T"
+ key: "DstT"
value {
type: DT_FLOAT
}
}
attr {
+ key: "SrcT"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
key: "_output_shapes"
value {
list {
shape {
- dim {
- size: -1
- }
- dim {
- size: 784
- }
}
}
}
}
+ }
+ node {
+ name: "gradients/Mean_grad/truediv"
+ op: "RealDiv"
+ input: "gradients/Mean_grad/Tile"
+ input: "gradients/Mean_grad/Cast"
attr {
- key: "transpose_a"
+ key: "T"
value {
- b: false
+ type: DT_FLOAT
}
}
attr {
- key: "transpose_b"
+ key: "_output_shapes"
value {
- b: true
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
}
}
}
node {
- name: "gradients/MatMul_grad/MatMul_1"
- op: "MatMul"
- input: "x"
- input: "gradients/add_grad/tuple/control_dependency"
+ name: "gradients/Reshape_2_grad/Shape"
+ op: "Shape"
+ input: "SoftmaxCrossEntropyWithLogits"
attr {
key: "T"
value {
@@ -3222,39 +3261,24 @@ meta_graphs {
list {
shape {
dim {
- size: 784
- }
- dim {
- size: 10
+ size: 1
}
}
}
}
}
attr {
- key: "transpose_a"
- value {
- b: true
- }
- }
- attr {
- key: "transpose_b"
+ key: "out_type"
value {
- b: false
+ type: DT_INT32
}
}
}
node {
- name: "gradients/MatMul_grad/tuple/group_deps"
- op: "NoOp"
- input: "^gradients/MatMul_grad/MatMul"
- input: "^gradients/MatMul_grad/MatMul_1"
- }
- node {
- name: "gradients/MatMul_grad/tuple/control_dependency"
- op: "Identity"
- input: "gradients/MatMul_grad/MatMul"
- input: "^gradients/MatMul_grad/tuple/group_deps"
+ name: "gradients/Reshape_2_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Mean_grad/truediv"
+ input: "gradients/Reshape_2_grad/Shape"
attr {
key: "T"
value {
@@ -3262,11 +3286,9 @@ meta_graphs {
}
}
attr {
- key: "_class"
+ key: "Tshape"
value {
- list {
- s: "loc:@gradients/MatMul_grad/MatMul"
- }
+ type: DT_INT32
}
}
attr {
@@ -3277,19 +3299,15 @@ meta_graphs {
dim {
size: -1
}
- dim {
- size: 784
- }
}
}
}
}
}
node {
- name: "gradients/MatMul_grad/tuple/control_dependency_1"
- op: "Identity"
- input: "gradients/MatMul_grad/MatMul_1"
- input: "^gradients/MatMul_grad/tuple/group_deps"
+ name: "gradients/zeros_like"
+ op: "ZerosLike"
+ input: "SoftmaxCrossEntropyWithLogits:1"
attr {
key: "T"
value {
@@ -3297,23 +3315,15 @@ meta_graphs {
}
}
attr {
- key: "_class"
- value {
- list {
- s: "loc:@gradients/MatMul_grad/MatMul_1"
- }
- }
- }
- attr {
key: "_output_shapes"
value {
list {
shape {
dim {
- size: 784
+ size: -1
}
dim {
- size: 10
+ size: -1
}
}
}
@@ -3321,7 +3331,7 @@ meta_graphs {
}
}
node {
- name: "GradientDescent/learning_rate"
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
op: "Const"
attr {
key: "_output_shapes"
@@ -3335,27 +3345,26 @@ meta_graphs {
attr {
key: "dtype"
value {
- type: DT_FLOAT
+ type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
- dtype: DT_FLOAT
+ dtype: DT_INT32
tensor_shape {
}
- float_val: 0.00999999977648
+ int_val: -1
}
}
}
}
node {
- name: "GradientDescent/update_Variable/ApplyGradientDescent"
- op: "ApplyGradientDescent"
- input: "Variable"
- input: "GradientDescent/learning_rate"
- input: "gradients/MatMul_grad/tuple/control_dependency_1"
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ op: "ExpandDims"
+ input: "gradients/Reshape_2_grad/Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
attr {
key: "T"
value {
@@ -3363,11 +3372,9 @@ meta_graphs {
}
}
attr {
- key: "_class"
+ key: "Tdim"
value {
- list {
- s: "loc:@Variable"
- }
+ type: DT_INT32
}
}
attr {
@@ -3376,28 +3383,21 @@ meta_graphs {
list {
shape {
dim {
- size: 784
+ size: -1
}
dim {
- size: 10
+ size: 1
}
}
}
}
}
- attr {
- key: "use_locking"
- value {
- b: false
- }
- }
}
node {
- name: "GradientDescent/update_Variable_1/ApplyGradientDescent"
- op: "ApplyGradientDescent"
- input: "Variable_1"
- input: "GradientDescent/learning_rate"
- input: "gradients/add_grad/tuple/control_dependency_1"
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ op: "Mul"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ input: "SoftmaxCrossEntropyWithLogits:1"
attr {
key: "T"
value {
@@ -3405,73 +3405,87 @@ meta_graphs {
}
}
attr {
- key: "_class"
+ key: "_output_shapes"
value {
list {
- s: "loc:@Variable_1"
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
}
}
}
+ }
+ node {
+ name: "gradients/Reshape_grad/Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
- size: 10
+ size: 2
}
}
}
}
}
attr {
- key: "use_locking"
+ key: "out_type"
value {
- b: false
+ type: DT_INT32
}
}
}
node {
- name: "GradientDescent"
- op: "NoOp"
- input: "^GradientDescent/update_Variable/ApplyGradientDescent"
- input: "^GradientDescent/update_Variable_1/ApplyGradientDescent"
- }
- node {
- name: "TopKV2/k"
- op: "Const"
+ name: "gradients/Reshape_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ input: "gradients/Reshape_grad/Shape"
attr {
- key: "_output_shapes"
+ key: "T"
value {
- list {
- shape {
- }
- }
+ type: DT_FLOAT
}
}
attr {
- key: "dtype"
+ key: "Tshape"
value {
type: DT_INT32
}
}
attr {
- key: "value"
+ key: "_output_shapes"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
}
- int_val: 10
}
}
}
}
node {
- name: "TopKV2"
- op: "TopKV2"
- input: "y"
- input: "TopKV2/k"
+ name: "gradients/add_grad/Shape"
+ op: "Shape"
+ input: "MatMul"
attr {
key: "T"
value {
@@ -3484,32 +3498,21 @@ meta_graphs {
list {
shape {
dim {
- size: -1
- }
- dim {
- size: 10
- }
- }
- shape {
- dim {
- size: -1
- }
- dim {
- size: 10
+ size: 2
}
}
}
}
}
attr {
- key: "sorted"
+ key: "out_type"
value {
- b: true
+ type: DT_INT32
}
}
}
node {
- name: "Const_1"
+ name: "gradients/add_grad/Shape_1"
op: "Const"
attr {
key: "_output_shapes"
@@ -3517,7 +3520,7 @@ meta_graphs {
list {
shape {
dim {
- size: 10
+ size: 1
}
}
}
@@ -3526,133 +3529,207 @@ meta_graphs {
attr {
key: "dtype"
value {
- type: DT_STRING
+ type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
- dtype: DT_STRING
+ dtype: DT_INT32
tensor_shape {
dim {
- size: 10
+ size: 1
}
}
- string_val: "0"
- string_val: "1"
- string_val: "2"
- string_val: "3"
- string_val: "4"
- string_val: "5"
- string_val: "6"
- string_val: "7"
- string_val: "8"
- string_val: "9"
+ int_val: 10
}
}
}
}
node {
- name: "index_to_string/Size"
- op: "Const"
+ name: "gradients/add_grad/BroadcastGradientArgs"
+ op: "BroadcastGradientArgs"
+ input: "gradients/add_grad/Shape"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
attr {
key: "_output_shapes"
value {
list {
shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
}
}
}
}
+ }
+ node {
+ name: "gradients/add_grad/Sum"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs"
attr {
- key: "dtype"
+ key: "T"
value {
- type: DT_INT32
+ type: DT_FLOAT
}
}
attr {
- key: "value"
+ key: "Tidx"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- }
- int_val: 10
- }
+ type: DT_INT32
}
}
- }
- node {
- name: "index_to_string/range/start"
- op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
+ unknown_rank: true
}
}
}
}
attr {
- key: "dtype"
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum"
+ input: "gradients/add_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
value {
type: DT_INT32
}
}
attr {
- key: "value"
+ key: "_output_shapes"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
}
- int_val: 0
}
}
}
}
node {
- name: "index_to_string/range/delta"
- op: "Const"
+ name: "gradients/add_grad/Sum_1"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
attr {
key: "_output_shapes"
value {
list {
shape {
+ unknown_rank: true
}
}
}
}
attr {
- key: "dtype"
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape_1"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum_1"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
value {
type: DT_INT32
}
}
attr {
- key: "value"
+ key: "_output_shapes"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
}
- int_val: 1
}
}
}
}
node {
- name: "index_to_string/range"
- op: "Range"
- input: "index_to_string/range/start"
- input: "index_to_string/Size"
- input: "index_to_string/range/delta"
+ name: "gradients/add_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/Reshape_1"
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/tuple/group_deps"
attr {
- key: "Tidx"
+ key: "T"
value {
- type: DT_INT32
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape"
+ }
}
}
attr {
@@ -3661,6 +3738,9 @@ meta_graphs {
list {
shape {
dim {
+ size: -1
+ }
+ dim {
size: 10
}
}
@@ -3669,19 +3749,22 @@ meta_graphs {
}
}
node {
- name: "index_to_string/ToInt64"
- op: "Cast"
- input: "index_to_string/range"
+ name: "gradients/add_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape_1"
+ input: "^gradients/add_grad/tuple/group_deps"
attr {
- key: "DstT"
+ key: "T"
value {
- type: DT_INT64
+ type: DT_FLOAT
}
}
attr {
- key: "SrcT"
+ key: "_class"
value {
- type: DT_INT32
+ list {
+ s: "loc:@gradients/add_grad/Reshape_1"
+ }
}
}
attr {
@@ -3698,111 +3781,207 @@ meta_graphs {
}
}
node {
- name: "index_to_string"
- op: "HashTableV2"
+ name: "gradients/MatMul_grad/MatMul"
+ op: "MatMul"
+ input: "gradients/add_grad/tuple/control_dependency"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
attr {
key: "_output_shapes"
value {
list {
shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
}
}
}
}
attr {
- key: "container"
+ key: "transpose_a"
value {
- s: ""
+ b: false
}
}
attr {
- key: "key_dtype"
+ key: "transpose_b"
value {
- type: DT_INT64
+ b: true
}
}
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul_1"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "gradients/add_grad/tuple/control_dependency"
attr {
- key: "shared_name"
+ key: "T"
value {
- s: ""
+ type: DT_FLOAT
}
}
attr {
- key: "use_node_name_sharing"
+ key: "_output_shapes"
value {
- b: false
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
}
}
attr {
- key: "value_dtype"
+ key: "transpose_a"
value {
- type: DT_STRING
+ b: true
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
}
}
}
node {
- name: "index_to_string/Const"
- op: "Const"
+ name: "gradients/MatMul_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/MatMul_1"
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul"
+ }
+ }
+ }
attr {
key: "_output_shapes"
value {
list {
shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
}
}
}
}
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul_1"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
attr {
- key: "dtype"
+ key: "T"
value {
- type: DT_STRING
+ type: DT_FLOAT
}
}
attr {
- key: "value"
+ key: "_class"
value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
}
- string_val: "UNK"
}
}
}
}
node {
- name: "index_to_string/table_init"
- op: "InitializeTableV2"
- input: "index_to_string"
- input: "index_to_string/ToInt64"
- input: "Const_1"
+ name: "GradientDescent/learning_rate"
+ op: "Const"
attr {
- key: "Tkey"
+ key: "_output_shapes"
value {
- type: DT_INT64
+ list {
+ shape {
+ }
+ }
}
}
attr {
- key: "Tval"
+ key: "dtype"
value {
- type: DT_STRING
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.5
+ }
}
}
}
node {
- name: "ToInt64"
- op: "Cast"
- input: "TopKV2:1"
+ name: "GradientDescent/update_Variable/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/MatMul_grad/tuple/control_dependency_1"
attr {
- key: "DstT"
+ key: "T"
value {
- type: DT_INT64
+ type: DT_FLOAT
}
}
attr {
- key: "SrcT"
+ key: "_class"
value {
- type: DT_INT32
+ list {
+ s: "loc:@Variable"
+ }
}
}
attr {
@@ -3811,7 +3990,7 @@ meta_graphs {
list {
shape {
dim {
- size: -1
+ size: 784
}
dim {
size: 10
@@ -3820,23 +3999,31 @@ meta_graphs {
}
}
}
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
}
node {
- name: "index_to_string_Lookup"
- op: "LookupTableFindV2"
- input: "index_to_string"
- input: "ToInt64"
- input: "index_to_string/Const"
+ name: "GradientDescent/update_Variable_1/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable_1"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/add_grad/tuple/control_dependency_1"
attr {
- key: "Tin"
+ key: "T"
value {
- type: DT_INT64
+ type: DT_FLOAT
}
}
attr {
- key: "Tout"
+ key: "_class"
value {
- type: DT_STRING
+ list {
+ s: "loc:@Variable_1"
+ }
}
}
attr {
@@ -3845,15 +4032,30 @@ meta_graphs {
list {
shape {
dim {
- size: -1
- }
- dim {
size: 10
}
}
}
}
}
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent"
+ op: "NoOp"
+ input: "^GradientDescent/update_Variable/ApplyGradientDescent"
+ input: "^GradientDescent/update_Variable_1/ApplyGradientDescent"
+ }
+ node {
+ name: "init"
+ op: "NoOp"
+ input: "^Variable/Assign"
+ input: "^Variable_1/Assign"
}
node {
name: "ArgMax/dimension"
@@ -3888,7 +4090,7 @@ meta_graphs {
node {
name: "ArgMax"
op: "ArgMax"
- input: "y"
+ input: "add"
input: "ArgMax/dimension"
attr {
key: "T"
@@ -3954,7 +4156,7 @@ meta_graphs {
node {
name: "ArgMax_1"
op: "ArgMax"
- input: "Placeholder"
+ input: "Placeholder_1"
input: "ArgMax_1/dimension"
attr {
key: "T"
@@ -4012,7 +4214,7 @@ meta_graphs {
}
}
node {
- name: "Cast"
+ name: "Cast_1"
op: "Cast"
input: "Equal"
attr {
@@ -4041,7 +4243,7 @@ meta_graphs {
}
}
node {
- name: "Const_2"
+ name: "Const_1"
op: "Const"
attr {
key: "_output_shapes"
@@ -4077,10 +4279,10 @@ meta_graphs {
}
}
node {
- name: "Mean"
+ name: "Mean_1"
op: "Mean"
- input: "Cast"
- input: "Const_2"
+ input: "Cast_1"
+ input: "Const_1"
attr {
key: "T"
value {
@@ -4110,16 +4312,6 @@ meta_graphs {
}
}
node {
- name: "init_all_tables"
- op: "NoOp"
- input: "^index_to_string/table_init"
- }
- node {
- name: "legacy_init_op"
- op: "NoOp"
- input: "^init_all_tables"
- }
- node {
name: "save/Const"
op: "Const"
attr {
@@ -4174,7 +4366,7 @@ meta_graphs {
dtype: DT_STRING
tensor_shape {
}
- string_val: "_temp_8390c48b96834292ab57050d0ae6959e/part"
+ string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part"
}
}
}
@@ -4783,22 +4975,6 @@ meta_graphs {
version: V2
}
collection_def {
- key: "legacy_init_op"
- value {
- node_list {
- value: "legacy_init_op"
- }
- }
- }
- collection_def {
- key: "table_initializer"
- value {
- node_list {
- value: "index_to_string/table_init"
- }
- }
- }
- collection_def {
key: "train_op"
value {
node_list {
@@ -4810,8 +4986,8 @@ meta_graphs {
key: "trainable_variables"
value {
bytes_list {
- value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0"
- value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0"
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
}
}
}
@@ -4819,18 +4995,18 @@ meta_graphs {
key: "variables"
value {
bytes_list {
- value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0"
- value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0"
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
}
}
}
signature_def {
- key: "predict_images"
+ key: "serving_default"
value {
inputs {
- key: "images"
+ key: "x"
value {
- name: "x:0"
+ name: "Placeholder:0"
dtype: DT_FLOAT
tensor_shape {
dim {
@@ -4843,9 +5019,9 @@ meta_graphs {
}
}
outputs {
- key: "scores"
+ key: "y"
value {
- name: "y:0"
+ name: "add:0"
dtype: DT_FLOAT
tensor_shape {
dim {
@@ -4860,50 +5036,4 @@ meta_graphs {
method_name: "tensorflow/serving/predict"
}
}
- signature_def {
- key: "serving_default"
- value {
- inputs {
- key: "inputs"
- value {
- name: "tf_example:0"
- dtype: DT_STRING
- tensor_shape {
- unknown_rank: true
- }
- }
- }
- outputs {
- key: "classes"
- value {
- name: "index_to_string_Lookup:0"
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
- }
- }
- }
- outputs {
- key: "scores"
- value {
- name: "TopKV2:0"
- dtype: DT_FLOAT
- tensor_shape {
- dim {
- size: -1
- }
- dim {
- size: 10
- }
- }
- }
- }
- method_name: "tensorflow/serving/classify"
- }
- }
}
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
new file mode 100644
index 00000000000..8474aa0a04c
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
new file mode 100644
index 00000000000..cfcdac20409
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001
deleted file mode 100644
index ba71c21fbe1..00000000000
--- a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001
+++ /dev/null
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index
deleted file mode 100644
index 84f4593515a..00000000000
--- a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index
+++ /dev/null
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 82e5d0cfe5b..3aa2d144f1f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.*;
-import com.yahoo.tensor.Tensor;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import org.junit.Test;
+
import static org.junit.Assert.assertEquals;
/**
@@ -83,7 +88,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0, "sin(0)");
tester.assertEvaluates(1, "cos(0)");
tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))");
-
+
// Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero)
tester.assertEvaluates(0, "random(1)");
tester.assertEvaluates(0, "random(foo)");
@@ -152,7 +157,7 @@ public class EvaluationTestCase {
"tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
-
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ee2b1c147e3..ba0db4de5e1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -34,7 +34,7 @@ public class EvaluationTester {
}
// TODO: Test both bound and unbound indexed
- public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
+ public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
String ... tensorArgumentStrings) {
MapContext context = defaultContext.thawedCopy();
int argumentIndex = 0;
@@ -46,7 +46,7 @@ public class EvaluationTester {
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
context.put("tensor" + (argumentIndex++), new TensorValue(argument));
}
- return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
+ return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
mappedTensors ? "Mapped tensors" : "Indexed tensors");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
new file mode 100644
index 00000000000..dab42801d70
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -0,0 +1,114 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Mnist_SoftmaxTestCase {
+
+ @Test
+ public void testImporting() {
+ String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ ImportResult result = new TensorFlowImporter().importModel(modelDir);
+
+ // Check logged messages
+ result.warnings().forEach(System.err::println);
+ assertEquals(0, result.warnings().size());
+
+ // Check arguments
+ assertEquals(1, result.arguments().size());
+ TensorType argument0 = result.arguments().get("Placeholder");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // Check constants
+ assertEquals(2, result.constants().size());
+
+ Tensor constant0 = result.constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = result.constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check resulting Vespa expression
+ assertEquals(1, result.expressions().size());
+ assertEquals("y", result.expressions().get(0).getName());
+ assertEquals("" +
+ "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
+ "rename(constant(Variable_1), d0, d1), " +
+ "f(a,b)(a + b))",
+ toNonPrimitiveString(result.expressions().get(0)));
+
+ // Test execution
+ String signatureName = "serving_default";
+
+ assertEqualResult(modelDir, signatureName, "Variable/read");
+ assertEqualResult(modelDir, signatureName, "Variable_1/read");
+ // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
+ assertEqualResult(modelDir, signatureName, "MatMul");
+ assertEqualResult(modelDir, signatureName, "add");
+ }
+
+ private void assertEqualResult(String modelDir, String signatureName, String operationName) {
+ ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
+
+ Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+ Context context = contextFrom(result);
+ Tensor placeholder = placeholderArgument();
+ context.put("Placeholder", new TensorValue(placeholder));
+ Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ }
+
+ private Tensor tensorFlowExecute(String modelDir, String operationName) {
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+ runner.feed("Placeholder", placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(ImportResult result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private String toNonPrimitiveString(RankingExpression expression) {
+ // toString on the wrapping expression will map to primitives, which is harder to read
+ return ((TensorFunctionNode)expression.getRoot()).function().toString();
+ }
+
+ private Tensor placeholderArgument() {
+ int size = 784;
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
+ for (int i = 0; i < size; i++)
+ b.cell(0, 0, i);
+ return b.build();
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
deleted file mode 100644
index aaf198a9e8f..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ /dev/null
@@ -1,79 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.logging.Level;
-import java.util.stream.Collectors;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author bratseth
- */
-public class TensorFlowImporterTestCase {
-
- @Test
- public void testModel1() {
- List<NamedTensor> constants = new ArrayList<>();
- TestLogger logger = new TestLogger();
- List<RankingExpression> expressions =
- new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger);
-
- // Check constants
- assertEquals(2, constants.size());
-
- assertEquals("Variable", constants.get(0).name());
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
- constants.get(0).tensor().type());
- assertEquals(7840, constants.get(0).tensor().size());
-
- assertEquals("Variable_1", constants.get(1).name());
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
- constants.get(1).tensor().type());
- assertEquals(10, constants.get(1).tensor().size());
-
- // Check logged messages
- assertEquals(2, logger.messages().size());
- assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported",
- logger.messages().get(0));
- assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
- logger.messages().get(1));
-
- // Check resulting Vespa expression
- assertEquals(1, expressions.size());
- assertEquals("scores", expressions.get(0).getName());
- assertEquals("" +
- "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " +
- "constant(Variable_1), " +
- "f(a,b)(a + b)), " +
- "d0)",
- toNonPrimitiveString(expressions.get(0)));
- }
-
- private String toNonPrimitiveString(RankingExpression expression) {
- // toString on the wrapping expression will map to primitives, which is harder to read
- return ((TensorFunctionNode)expression.getRoot()).function().toString();
- }
-
- private class TestLogger implements TensorFlowImporter.MessageLogger {
-
- private List<String> messages = new ArrayList<>();
-
- /** Returns the messages in sorted order */
- public List<String> messages() {
- return messages.stream().sorted().collect(Collectors.toList());
- }
-
- @Override
- public void log(Level level, String message) {
- messages.add(message);
- }
-
- }
-
-}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index dde9d4bf21e..1960c1fe876 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -59,7 +59,7 @@ public class TensorConformanceTest {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
-
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
return true;
@@ -67,7 +67,7 @@ public class TensorConformanceTest {
if (!node.has("expression")) {
return true; // ignore
}
-
+
String expression = node.get("expression").asText();
MapContext context = getInput(node.get("inputs"));
Tensor expect = getTensor(node.get("result").get("expect").asText());