From dc0f70fac9167acf487453daf565636c675934df Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 9 Jan 2018 16:07:43 +0100 Subject: Basic TensorFlow integration This wil replace any occurrence of tensorflow(...) in ranking expressions with the corresponding translated expression. It is functional but these tings are outstanding - Propagate warnings - Import a model just once even if referred multiple times - Add constants as tensor files rather than config --- .../ConstantTensorTransformer.java | 76 - .../searchdefinition/FieldOperationApplier.java | 3 +- .../FieldOperationApplierForSearch.java | 2 +- .../FieldOperationApplierForStructs.java | 4 +- .../com/yahoo/searchdefinition/MacroInliner.java | 39 - .../com/yahoo/searchdefinition/MacroShadower.java | 65 - .../com/yahoo/searchdefinition/RankProfile.java | 36 +- .../yahoo/searchdefinition/TensorTransformer.java | 290 -- .../UnprocessingSearchBuilder.java | 1 - .../searchdefinition/UnrankedRankProfile.java | 1 + .../ConstantTensorTransformer.java | 76 + .../expressiontransforms/ExpressionTransforms.java | 34 + .../expressiontransforms/MacroInliner.java | 40 + .../expressiontransforms/MacroShadower.java | 66 + .../TensorFlowFeatureConverter.java | 114 + .../expressiontransforms/TensorTransformer.java | 293 ++ .../processing/PredicateProcessor.java | 2 +- .../mnist_softmax/mnist_sftmax_with_saving.py | 89 + .../mnist_softmax/saved/saved_model.pbtxt | 5039 ++++++++++++++++++++ .../saved/variables/variables.data-00000-of-00001 | Bin 0 -> 31400 bytes .../mnist_softmax/saved/variables/variables.index | Bin 0 -> 159 bytes .../processing/RankProfileSearchFixture.java | 58 + .../RankingExpressionWithTensorFlowTestCase.java | 119 + .../RankingExpressionWithTensorTestCase.java | 56 +- 24 files changed, 5964 insertions(+), 539 deletions(-) delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java create mode 100644 config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py create mode 100644 config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt create mode 100644 config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 create mode 100644 config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java deleted file mode 100644 index c75864f81b7..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.NameNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -/** - * Transforms named references to constant tensors with the rank feature 'constant'. - * - * @author geirst - */ -class ConstantTensorTransformer extends ExpressionTransformer { - - public static final String CONSTANT = "constant"; - - private final Map constants; - private final Map rankPropertiesOutput; - - public ConstantTensorTransformer(Map constants, - Map rankPropertiesOutput) { - this.constants = constants; - this.rankPropertiesOutput = rankPropertiesOutput; - } - - @Override - public ExpressionNode transform(ExpressionNode node) { - if (node instanceof ReferenceNode) { - return transformFeature((ReferenceNode) node); - } else if (node instanceof CompositeNode) { - return transformChildren((CompositeNode) node); - } else { - return node; - } - } - - private ExpressionNode transformFeature(ReferenceNode node) { - if (!node.getArguments().isEmpty()) { - return transformArguments(node); - } else { - return transformConstantReference(node); - } - } - - private ExpressionNode transformArguments(ReferenceNode node) { - List arguments = node.getArguments().expressions(); - List transformedArguments = new ArrayList<>(arguments.size()); - for (ExpressionNode argument : arguments) { - transformedArguments.add(transform(argument)); - } - return node.setArguments(transformedArguments); - } - - private ExpressionNode transformConstantReference(ReferenceNode node) { - Value value = constants.get(node.getName()); - if (value == null || !(value instanceof TensorValue)) { - return node; - } - TensorValue tensorValue = (TensorValue)value; - String featureName = CONSTANT + "(" + node.getName() + ")"; - String tensorType = tensorValue.asTensor().type().toString(); - rankPropertiesOutput.put(featureName + ".value", tensorValue.toString()); - rankPropertiesOutput.put(featureName + ".type", tensorType); - return new ReferenceNode("constant", Arrays.asList(new NameNode(node.getName())), null); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java index cdd089d9bf7..f9d71f03972 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java @@ -6,9 +6,10 @@ import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; /** - * @author Einar M R Rosenvinge + * @author Einar M R Rosenvinge */ public class FieldOperationApplier { + public void process(SDDocumentType sdoc) { if (!sdoc.isStruct()) { apply(sdoc); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java index addaa4bc632..1019b794cdd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java @@ -5,7 +5,7 @@ import com.yahoo.document.Field; import com.yahoo.searchdefinition.document.SDDocumentType; /** - * @author Einar M R Rosenvinge + * @author Einar M R Rosenvinge */ public class FieldOperationApplierForSearch extends FieldOperationApplier { @Override diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java index 30fdcd01dd4..04b0fc6e331 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java @@ -11,9 +11,10 @@ import java.util.Iterator; import java.util.List; /** - * @author Einar M R Rosenvinge + * @author Einar M R Rosenvinge */ public class FieldOperationApplierForStructs extends FieldOperationApplier { + @Override public void process(SDDocumentType sdoc) { for (SDDocumentType type : sdoc.getAllTypes()) { @@ -45,4 +46,5 @@ public class FieldOperationApplierForStructs extends FieldOperationApplier { } } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java b/config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java deleted file mode 100644 index 4702fac30a8..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; - -import java.util.Map; - -/** - * Inlines macros in ranking expressions - * - * @author bratseth - */ -class MacroInliner extends ExpressionTransformer { - - private final Map macros; - - public MacroInliner(Map macros) { - this.macros = macros; - } - - @Override - public ExpressionNode transform(ExpressionNode node) { - if (node instanceof ReferenceNode) - return transformFeatureNode((ReferenceNode)node); - if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); - return node; - } - - private ExpressionNode transformFeatureNode(ReferenceNode feature) { - RankProfile.Macro macro = macros.get(feature.getName()); - if (macro == null) return feature; - return transform(macro.getRankingExpression().getRoot()); // inline recursively and return - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java deleted file mode 100644 index edf0ce69819..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; - -import java.util.Map; - -/** - * Transforms function nodes to reference nodes if a macro shadows a built-in function. - * This has the effect of allowing macros to redefine built-in functions. - * Another effect is that we can more or less add built-in functions over time - * without fear of breaking existing users' macros with the same name. - * - * However, there is a (largish) caveat. If a user has a macro with a certain number - * of arguments, and we add in a built-in function with a different arity, - * this will cause parse errors as the Java parser gives precedence to - * built-in functions. - * - * @author lesters - */ -class MacroShadower extends ExpressionTransformer { - - private final Map macros; - - public MacroShadower(Map macros) { - this.macros = macros; - } - - @Override - public RankingExpression transform(RankingExpression expression) { - String name = expression.getName(); - ExpressionNode node = expression.getRoot(); - ExpressionNode result = transform(node); - return new RankingExpression(name, result); - } - - @Override - public ExpressionNode transform(ExpressionNode node) { - if (node instanceof FunctionNode) - return transformFunctionNode((FunctionNode) node); - if (node instanceof CompositeNode) - return transformChildren((CompositeNode)node); - return node; - } - - protected ExpressionNode transformFunctionNode(FunctionNode function) { - String name = function.getFunction().toString(); - RankProfile.Macro macro = macros.get(name); - if (macro == null) { - return transformChildren(function); - } - - int functionArity = function.getFunction().arity(); - int macroArity = macro.getFormalParams() != null ? macro.getFormalParams().size() : 0; - if (functionArity != macroArity) { - return transformChildren(function); - } - - ReferenceNode node = new ReferenceNode(name, function.children(), null); - return transformChildren(node); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 1021227b0e6..cf92d1f979b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -1,7 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.search.query.ranking.Diversity; +import com.yahoo.searchdefinition.expressiontransforms.ExpressionTransforms; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; @@ -9,13 +11,22 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode; -import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; -import com.yahoo.searchlib.rankingexpression.transform.Simplifier; -import com.yahoo.config.application.api.ApplicationPackage; -import java.io.*; -import java.util.*; +import java.io.File; +import java.io.IOException; +import java.io.Reader; +import java.io.Serializable; +import java.io.StringReader; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; /** * Represents a rank profile - a named set of ranking settings @@ -40,7 +51,7 @@ public class RankProfile implements Serializable, Cloneable { protected Set rankSettings = new java.util.LinkedHashSet<>(); /** The ranking expression to be used for first phase */ - private RankingExpression firstPhaseRanking= null; + private RankingExpression firstPhaseRanking= null; /** The ranking expression to be used for second phase */ private RankingExpression secondPhaseRanking = null; @@ -485,7 +496,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the string form of the second phase ranking expression. - * + * * @return string form of second phase ranking expression */ public String getSecondPhaseRankingString() { @@ -702,12 +713,7 @@ public class RankProfile implements Serializable, Cloneable { Map inlineMacros) { if (expression == null) return null; Map rankPropertiesOutput = new HashMap<>(); - expression = new ConstantDereferencer(constants).transform(expression); - expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); - expression = new MacroInliner(inlineMacros).transform(expression); - expression = new MacroShadower(getMacros()).transform(expression); - expression = new TensorTransformer(this).transform(expression); - expression = new Simplifier().transform(expression); + expression = new ExpressionTransforms().transform(expression, this, constants, inlineMacros, rankPropertiesOutput); for (Map.Entry rankProperty : rankPropertiesOutput.entrySet()) { addRankProperty(rankProperty.getKey(), rankProperty.getValue()); } @@ -975,7 +981,7 @@ public class RankProfile implements Serializable, Cloneable { throw new IllegalArgumentException("match-phase did not set max-hits > 0"); } } - + } public static class TypeSettings { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java deleted file mode 100644 index 65176006a2a..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.searchdefinition.document.Attribute; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; - -import java.util.List; -import java.util.Map; -import java.util.Optional; - -/** - * Transforms and simplifies tensor expressions. - * - * Currently transforms min(tensor,dim) and max(tensor,dim) to - * reduce(tensor,min/max,dim). This is necessary as the backend does - * not recognize these forms of min and max. - * - * @author lesters - */ -public class TensorTransformer extends ExpressionTransformer { - - private Search search; - private RankProfile rankprofile; - private Map macros; - - public TensorTransformer(RankProfile rankprofile) { - this.rankprofile = rankprofile; - this.search = rankprofile.getSearch(); - this.macros = rankprofile.getMacros(); - } - - @Override - public ExpressionNode transform(ExpressionNode node) { - if (node instanceof CompositeNode) { - node = transformChildren((CompositeNode) node); - } - if (node instanceof FunctionNode) { - node = transformFunctionNode((FunctionNode) node); - } - return node; - } - - private ExpressionNode transformFunctionNode(FunctionNode node) { - switch (node.getFunction()) { - case min: - case max: - return transformMaxAndMinFunctionNode(node); - } - return node; - } - - /** - * Transforms max and min functions if it can be proven that the first - * argument resolves to a tensor and the second argument is a valid - * dimension in the tensor. If these do not hold, the node will not - * be transformed. - * - * The test for whether or not the first argument resolves to a tensor - * is to evaluate that expression. All values used in the expression - * is bound to a context with dummy values with enough information to - * deduce tensor types. - * - * There is currently no guarantee that all cases will be found. For - * instance, if-statements are problematic. - */ - private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) { - if (node.children().size() != 2) { - return node; - } - ExpressionNode arg1 = node.children().get(0); - Optional dimension = dimensionName(node.children().get(1)); - if (dimension.isPresent()) { - try { - Context context = buildContext(arg1); - Value value = arg1.evaluate(context); - if (isTensorWithDimension(value, dimension.get())) { - return replaceMaxAndMinFunction(node); - } - } catch (IllegalArgumentException e) { - // Thrown from evaluate if some variables are not bound, for - // instance for a backend rank feature. Means we don't have - // enough information to replace expression. - } - } - return node; - } - - private Optional dimensionName(ExpressionNode arg) { - if (arg instanceof ReferenceNode && ((ReferenceNode)arg).children().size() == 0) { - return Optional.of(((ReferenceNode) arg).getName()); - } - return Optional.empty(); - } - - private boolean isTensorWithDimension(Value value, String dimension) { - if (value instanceof TensorValue) { - Tensor tensor = ((TensorValue) value).asTensor(); - TensorType type = tensor.type(); - return type.dimensionNames().contains(dimension); - } - return false; - } - - private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { - ExpressionNode arg1 = node.children().get(0); - ExpressionNode arg2 = node.children().get(1); - - TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); - Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); - String dimension = ((ReferenceNode) arg2).getName(); - - return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); - } - - /** - * Creates an evaluation context by iterating through the expression tree, and - * adding dummy values with correct types to the context. - */ - private Context buildContext(ExpressionNode node) { - Context context = new MapContext(); - addRoot(node, context); - return context; - } - - private Value emptyStringValue() { - return new StringValue(""); - } - - private Value emptyDoubleValue() { - return new DoubleValue(0.0); - } - - private Value emptyTensorValue(TensorType type) { - Tensor empty = Tensor.Builder.of(type).build(); - return new TensorValue(empty); - } - - private void addRoot(ExpressionNode node, Context context) { - addChildren(node, context); - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - addIfAttribute(referenceNode, context); - addIfConstant(referenceNode, context); - addIfQuery(referenceNode, context); - addIfTensorFrom(referenceNode, context); - addIfMacro(referenceNode, context); - } - } - - private void addChildren(ExpressionNode node, Context context) { - if (node instanceof CompositeNode) { - List children = ((CompositeNode) node).children(); - for (ExpressionNode child : children) { - addRoot(child, context); - } - } - } - - private void addIfAttribute(ReferenceNode node, Context context) { - if (!node.getName().equals("attribute")) { - return; - } - if (node.children().size() == 0) { - return; - } - String attribute = node.children().get(0).toString(); - Attribute a = search.getAttribute(attribute); - if (a == null) { - return; - } - Value v; - if (a.getType() == Attribute.Type.STRING) { - v = emptyStringValue(); - } else if (a.getType() == Attribute.Type.TENSOR) { - v = emptyTensorValue(a.tensorType().orElseThrow(RuntimeException::new)); - } else { - v = emptyDoubleValue(); - } - context.put(node.toString(), v); - } - - private void addIfConstant(ReferenceNode node, Context context) { - if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { - return; - } - if (node.children().size() != 1) { - return; - } - ExpressionNode child = node.children().get(0); - while (child instanceof CompositeNode && ((CompositeNode) child).children().size() > 0) { - child = ((CompositeNode) child).children().get(0); - } - String name = child.toString(); - addIfConstantInRankProfile(name, node, context); - addIfConstantInRankingConstants(name, node, context); - } - - private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) { - if (rankprofile.getConstants().containsKey(name)) { - context.put(node.toString(), rankprofile.getConstants().get(name)); - } - } - - private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) { - for (RankingConstant rankingConstant : search.getRankingConstants()) { - if (rankingConstant.getName().equals(name)) { - context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); - } - } - } - - private void addIfQuery(ReferenceNode node, Context context) { - if (!node.getName().equals("query")) { - return; - } - if (node.children().size() != 1) { - return; - } - String name = node.children().get(0).toString(); - if (rankprofile.getQueryFeatureTypes().containsKey(name)) { - String type = rankprofile.getQueryFeatureTypes().get(name); - Value v; - if (type.contains("tensor")) { - v = emptyTensorValue(TensorType.fromSpec(type)); - } else if (type.equalsIgnoreCase("string")) { - v = emptyStringValue(); - } else { - v = emptyDoubleValue(); - } - context.put(node.toString(), v); - } - } - - private void addIfTensorFrom(ReferenceNode node, Context context) { - if (!node.getName().startsWith("tensorFrom")) { - return; - } - if (node.children().size() < 1 || node.children().size() > 2) { - return; - } - ExpressionNode source = node.children().get(0); - if (source instanceof CompositeNode && ((CompositeNode) source).children().size() > 0) { - source = ((CompositeNode) source).children().get(0); - } - String dimension = source.toString(); - if (node.children().size() == 2) { - dimension = node.children().get(1).toString(); - } - TensorType type = (new TensorType.Builder()).mapped(dimension).build(); - context.put(node.toString(), emptyTensorValue(type)); - } - - private void addIfMacro(ReferenceNode node, Context context) { - RankProfile.Macro macro = macros.get(node.getName()); - if (macro == null) { - return; - } - ExpressionNode root = macro.getRankingExpression().getRoot(); - Context macroContext = buildContext(root); - addMacroArguments(node, context, macro, macroContext); - Value value = root.evaluate(macroContext); - context.put(node.toString(), value); - } - - private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) { - if (macro.getFormalParams().size() > 0 && node.children().size() > 0) { - for (int i = 0; i < macro.getFormalParams().size() && i < node.children().size(); ++i) { - String param = macro.getFormalParams().get(i); - ExpressionNode argumentExpression = node.children().get(i); - Value arg = argumentExpression.evaluate(context); - macroContext.put(param, arg); - } - } - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java index 1b292007ef3..b448005c6a5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java @@ -10,7 +10,6 @@ import java.io.IOException; /** * A SearchBuilder that does not run the processing chain for searches - * */ public class UnprocessingSearchBuilder extends SearchBuilder { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java index d8f7e56539e..b58f696cbdf 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java @@ -25,4 +25,5 @@ public class UnrankedRankProfile extends RankProfile { this.setKeepRankCount(0); this.setRerankCount(0); } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java new file mode 100644 index 00000000000..e061ead465e --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java @@ -0,0 +1,76 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Transforms named references to constant tensors with the rank feature 'constant'. + * + * @author geirst + */ +public class ConstantTensorTransformer extends ExpressionTransformer { + + public static final String CONSTANT = "constant"; + + private final Map constants; + private final Map rankPropertiesOutput; + + public ConstantTensorTransformer(Map constants, + Map rankPropertiesOutput) { + this.constants = constants; + this.rankPropertiesOutput = rankPropertiesOutput; + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof ReferenceNode) { + return transformFeature((ReferenceNode) node); + } else if (node instanceof CompositeNode) { + return transformChildren((CompositeNode) node); + } else { + return node; + } + } + + private ExpressionNode transformFeature(ReferenceNode node) { + if (!node.getArguments().isEmpty()) { + return transformArguments(node); + } else { + return transformConstantReference(node); + } + } + + private ExpressionNode transformArguments(ReferenceNode node) { + List arguments = node.getArguments().expressions(); + List transformedArguments = new ArrayList<>(arguments.size()); + for (ExpressionNode argument : arguments) { + transformedArguments.add(transform(argument)); + } + return node.setArguments(transformedArguments); + } + + private ExpressionNode transformConstantReference(ReferenceNode node) { + Value value = constants.get(node.getName()); + if (value == null || !(value instanceof TensorValue)) { + return node; + } + TensorValue tensorValue = (TensorValue)value; + String featureName = CONSTANT + "(" + node.getName() + ")"; + String tensorType = tensorValue.asTensor().type().toString(); + rankPropertiesOutput.put(featureName + ".value", tensorValue.toString()); + rankPropertiesOutput.put(featureName + ".type", tensorType); + return new ReferenceNode("constant", Arrays.asList(new NameNode(node.getName())), null); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java new file mode 100644 index 00000000000..ee5cccccb29 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -0,0 +1,34 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; +import com.yahoo.searchlib.rankingexpression.transform.Simplifier; + +import java.util.Map; + +/** + * The transformations done on ranking expressions done at config time before passing them on to the Vespa + * engine for execution. + * + * @author bratseth + */ +public class ExpressionTransforms { + + public RankingExpression transform(RankingExpression expression, + RankProfile rankProfile, + Map constants, + Map inlineMacros, + Map rankPropertiesOutput) { + expression = new TensorFlowFeatureConverter(rankProfile).transform(expression); + expression = new ConstantDereferencer(constants).transform(expression); + expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); + expression = new MacroInliner(inlineMacros).transform(expression); + expression = new MacroShadower(rankProfile.getMacros()).transform(expression); + expression = new TensorTransformer(rankProfile).transform(expression); + expression = new Simplifier().transform(expression); + return expression; + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java new file mode 100644 index 00000000000..a3933e6f8e2 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java @@ -0,0 +1,40 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.Map; + +/** + * Inlines macros in ranking expressions + * + * @author bratseth + */ +public class MacroInliner extends ExpressionTransformer { + + private final Map macros; + + public MacroInliner(Map macros) { + this.macros = macros; + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof ReferenceNode) + return transformFeatureNode((ReferenceNode)node); + if (node instanceof CompositeNode) + return transformChildren((CompositeNode)node); + return node; + } + + private ExpressionNode transformFeatureNode(ReferenceNode feature) { + RankProfile.Macro macro = macros.get(feature.getName()); + if (macro == null) return feature; + return transform(macro.getRankingExpression().getRoot()); // inline recursively and return + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java new file mode 100644 index 00000000000..1d9769d0d78 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java @@ -0,0 +1,66 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.*; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.Map; + +/** + * Transforms function nodes to reference nodes if a macro shadows a built-in function. + * This has the effect of allowing macros to redefine built-in functions. + * Another effect is that we can more or less add built-in functions over time + * without fear of breaking existing users' macros with the same name. + * + * However, there is a (largish) caveat. If a user has a macro with a certain number + * of arguments, and we add in a built-in function with a different arity, + * this will cause parse errors as the Java parser gives precedence to + * built-in functions. + * + * @author lesters + */ +public class MacroShadower extends ExpressionTransformer { + + private final Map macros; + + public MacroShadower(Map macros) { + this.macros = macros; + } + + @Override + public RankingExpression transform(RankingExpression expression) { + String name = expression.getName(); + ExpressionNode node = expression.getRoot(); + ExpressionNode result = transform(node); + return new RankingExpression(name, result); + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof FunctionNode) + return transformFunctionNode((FunctionNode) node); + if (node instanceof CompositeNode) + return transformChildren((CompositeNode)node); + return node; + } + + protected ExpressionNode transformFunctionNode(FunctionNode function) { + String name = function.getFunction().toString(); + RankProfile.Macro macro = macros.get(name); + if (macro == null) { + return transformChildren(function); + } + + int functionArity = function.getFunction().arity(); + int macroArity = macro.getFormalParams() != null ? macro.getFormalParams().size() : 0; + if (functionArity != macroArity) { + return transformChildren(function); + } + + ReferenceNode node = new ReferenceNode(name, function.children(), null); + return transformChildren(node); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java new file mode 100644 index 00000000000..e5886030d44 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -0,0 +1,114 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.Map; +import java.util.Optional; + +/** + * Replaces instances of the tensorflow(model-path, signature, output) + * pseudofeature with the native Vespa ranking expression implementing + * the same computation. + * + * @author bratseth + */ +public class TensorFlowFeatureConverter extends ExpressionTransformer { + + private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); + private final RankProfile profile; + + public TensorFlowFeatureConverter(RankProfile profile) { + this.profile = profile; + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature) { + try { + if ( ! feature.getName().equals("tensorflow")) return feature; + + if (feature.getArguments().isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + + // Find the specified expression + ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0))); + ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(), + optionalArgument(1, feature.getArguments())); + String output = chooseOrDefault("outputs", signature.outputs(), + optionalArgument(2, feature.getArguments())); + + // Add all constants + result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v))); + + return result.expressions().get(output).getRoot(); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e); + } + } + + /** + * Returns the specified, existing map value, or the only map value if no key is specified. + * Throws IllegalArgumentException in all other cases. + */ + private T chooseOrDefault(String valueDescription, Map map, Optional key) { + if ( ! key.isPresent()) { + if (map.size() == 0) + throw new IllegalArgumentException("No " + valueDescription + " are present"); + if (map.size() > 1) + throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " + + valueDescription + " argument is specified"); + return map.values().stream().findFirst().get(); + } + else { + T value = map.get(key.get()); + if (value == null) + throw new IllegalArgumentException("Model does not have the specified " + + valueDescription + " '" + key.get() + "'"); + return value; + } + } + + private Optional optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java new file mode 100644 index 00000000000..70a7372dbe9 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -0,0 +1,293 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.document.Attribute; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Transforms and simplifies tensor expressions. + * + * Currently transforms min(tensor,dim) and max(tensor,dim) to + * reduce(tensor,min/max,dim). This is necessary as the backend does + * not recognize these forms of min and max. + * + * @author lesters + */ +public class TensorTransformer extends ExpressionTransformer { + + private Search search; + private RankProfile rankprofile; + private Map macros; + + public TensorTransformer(RankProfile rankprofile) { + this.rankprofile = rankprofile; + this.search = rankprofile.getSearch(); + this.macros = rankprofile.getMacros(); + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof CompositeNode) { + node = transformChildren((CompositeNode) node); + } + if (node instanceof FunctionNode) { + node = transformFunctionNode((FunctionNode) node); + } + return node; + } + + private ExpressionNode transformFunctionNode(FunctionNode node) { + switch (node.getFunction()) { + case min: + case max: + return transformMaxAndMinFunctionNode(node); + } + return node; + } + + /** + * Transforms max and min functions if it can be proven that the first + * argument resolves to a tensor and the second argument is a valid + * dimension in the tensor. If these do not hold, the node will not + * be transformed. + * + * The test for whether or not the first argument resolves to a tensor + * is to evaluate that expression. All values used in the expression + * is bound to a context with dummy values with enough information to + * deduce tensor types. + * + * There is currently no guarantee that all cases will be found. For + * instance, if-statements are problematic. + */ + private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) { + if (node.children().size() != 2) { + return node; + } + ExpressionNode arg1 = node.children().get(0); + Optional dimension = dimensionName(node.children().get(1)); + if (dimension.isPresent()) { + try { + Context context = buildContext(arg1); + Value value = arg1.evaluate(context); + if (isTensorWithDimension(value, dimension.get())) { + return replaceMaxAndMinFunction(node); + } + } catch (IllegalArgumentException e) { + // Thrown from evaluate if some variables are not bound, for + // instance for a backend rank feature. Means we don't have + // enough information to replace expression. + } + } + return node; + } + + private Optional dimensionName(ExpressionNode arg) { + if (arg instanceof ReferenceNode && ((ReferenceNode)arg).children().size() == 0) { + return Optional.of(((ReferenceNode) arg).getName()); + } + return Optional.empty(); + } + + private boolean isTensorWithDimension(Value value, String dimension) { + if (value instanceof TensorValue) { + Tensor tensor = ((TensorValue) value).asTensor(); + TensorType type = tensor.type(); + return type.dimensionNames().contains(dimension); + } + return false; + } + + private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { + ExpressionNode arg1 = node.children().get(0); + ExpressionNode arg2 = node.children().get(1); + + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); + Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); + String dimension = ((ReferenceNode) arg2).getName(); + + return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + } + + /** + * Creates an evaluation context by iterating through the expression tree, and + * adding dummy values with correct types to the context. + */ + private Context buildContext(ExpressionNode node) { + Context context = new MapContext(); + addRoot(node, context); + return context; + } + + private Value emptyStringValue() { + return new StringValue(""); + } + + private Value emptyDoubleValue() { + return new DoubleValue(0.0); + } + + private Value emptyTensorValue(TensorType type) { + Tensor empty = Tensor.Builder.of(type).build(); + return new TensorValue(empty); + } + + private void addRoot(ExpressionNode node, Context context) { + addChildren(node, context); + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + addIfAttribute(referenceNode, context); + addIfConstant(referenceNode, context); + addIfQuery(referenceNode, context); + addIfTensorFrom(referenceNode, context); + addIfMacro(referenceNode, context); + } + } + + private void addChildren(ExpressionNode node, Context context) { + if (node instanceof CompositeNode) { + List children = ((CompositeNode) node).children(); + for (ExpressionNode child : children) { + addRoot(child, context); + } + } + } + + private void addIfAttribute(ReferenceNode node, Context context) { + if (!node.getName().equals("attribute")) { + return; + } + if (node.children().size() == 0) { + return; + } + String attribute = node.children().get(0).toString(); + Attribute a = search.getAttribute(attribute); + if (a == null) { + return; + } + Value v; + if (a.getType() == Attribute.Type.STRING) { + v = emptyStringValue(); + } else if (a.getType() == Attribute.Type.TENSOR) { + v = emptyTensorValue(a.tensorType().orElseThrow(RuntimeException::new)); + } else { + v = emptyDoubleValue(); + } + context.put(node.toString(), v); + } + + private void addIfConstant(ReferenceNode node, Context context) { + if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { + return; + } + if (node.children().size() != 1) { + return; + } + ExpressionNode child = node.children().get(0); + while (child instanceof CompositeNode && ((CompositeNode) child).children().size() > 0) { + child = ((CompositeNode) child).children().get(0); + } + String name = child.toString(); + addIfConstantInRankProfile(name, node, context); + addIfConstantInRankingConstants(name, node, context); + } + + private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) { + if (rankprofile.getConstants().containsKey(name)) { + context.put(node.toString(), rankprofile.getConstants().get(name)); + } + } + + private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) { + for (RankingConstant rankingConstant : search.getRankingConstants()) { + if (rankingConstant.getName().equals(name)) { + context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); + } + } + } + + private void addIfQuery(ReferenceNode node, Context context) { + if (!node.getName().equals("query")) { + return; + } + if (node.children().size() != 1) { + return; + } + String name = node.children().get(0).toString(); + if (rankprofile.getQueryFeatureTypes().containsKey(name)) { + String type = rankprofile.getQueryFeatureTypes().get(name); + Value v; + if (type.contains("tensor")) { + v = emptyTensorValue(TensorType.fromSpec(type)); + } else if (type.equalsIgnoreCase("string")) { + v = emptyStringValue(); + } else { + v = emptyDoubleValue(); + } + context.put(node.toString(), v); + } + } + + private void addIfTensorFrom(ReferenceNode node, Context context) { + if (!node.getName().startsWith("tensorFrom")) { + return; + } + if (node.children().size() < 1 || node.children().size() > 2) { + return; + } + ExpressionNode source = node.children().get(0); + if (source instanceof CompositeNode && ((CompositeNode) source).children().size() > 0) { + source = ((CompositeNode) source).children().get(0); + } + String dimension = source.toString(); + if (node.children().size() == 2) { + dimension = node.children().get(1).toString(); + } + TensorType type = (new TensorType.Builder()).mapped(dimension).build(); + context.put(node.toString(), emptyTensorValue(type)); + } + + private void addIfMacro(ReferenceNode node, Context context) { + RankProfile.Macro macro = macros.get(node.getName()); + if (macro == null) { + return; + } + ExpressionNode root = macro.getRankingExpression().getRoot(); + Context macroContext = buildContext(root); + addMacroArguments(node, context, macro, macroContext); + Value value = root.evaluate(macroContext); + context.put(node.toString(), value); + } + + private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) { + if (macro.getFormalParams().size() > 0 && node.children().size() > 0) { + for (int i = 0; i < macro.getFormalParams().size() && i < node.children().size(); ++i) { + String param = macro.getFormalParams().get(i); + ExpressionNode argumentExpression = node.children().get(i); + Value arg = argumentExpression.evaluate(context); + macroContext.put(param, arg); + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java index 450c24d8e3e..4b9b090cdc5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java @@ -24,7 +24,7 @@ import java.util.List; /** * Validates the predicate fields. * - * @author Lester Solbakken + * @author Lester Solbakken */ public class PredicateProcessor extends Processor { diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py new file mode 100644 index 00000000000..a1861a1c981 --- /dev/null +++ b/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py @@ -0,0 +1,89 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A very simple MNIST classifier. + +See extensive documentation at +https://www.tensorflow.org/get_started/mnist/beginners +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.examples.tutorials.mnist import input_data + +import tensorflow as tf + +FLAGS = None + + +def main(_): + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + W = tf.Variable(tf.zeros([784, 10])) + b = tf.Variable(tf.zeros([10])) + y = tf.matmul(x, W) + b + + # Define loss and optimizer + y_ = tf.placeholder(tf.float32, [None, 10]) + + # The raw formulation of cross-entropy, + # + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # reduction_indices=[1])) + # + # can be numerically unstable. + # + # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw + # outputs of 'y', and then average across the batch. + cross_entropy = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) + + sess = tf.InteractiveSession() + tf.global_variables_initializer().run() + # Train + for _ in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) + + # Test trained model + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + print(sess.run(accuracy, feed_dict={x: mnist.test.images, + y_: mnist.test.labels})) + + # Save the model + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt new file mode 100644 index 00000000000..8100dfd594d --- /dev/null +++ b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9b01" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable/Assign" + op: "Assign" + input: "Variable" + input: "zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable/read" + op: "Identity" + input: "Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable_1/Assign" + op: "Assign" + input: "Variable_1" + input: "zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable_1/read" + op: "Identity" + input: "Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "MatMul" + op: "MatMul" + input: "Placeholder" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "add" + op: "Add" + input: "MatMul" + input: "Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Shape" + op: "Shape" + input: "MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + } + node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_Variable/ApplyGradientDescent" + input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^Variable/Assign" + input: "^Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "Variable" + string_val: "Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "Variable" + input: "Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..8474aa0a04c Binary files /dev/null and b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 differ diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index new file mode 100644 index 00000000000..cfcdac20409 Binary files /dev/null and b/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index differ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java new file mode 100644 index 00000000000..e71a627d7db --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -0,0 +1,58 @@ +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.parser.ParseException; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper class for setting up and asserting over a Search instance with a rank profile given literally + * in the search definition language. + * + * @author geirst + */ +class RankProfileSearchFixture { + + private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + private Search search; + + RankProfileSearchFixture(String rankProfiles) throws ParseException { + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + String sdContent = "search test {\n" + + " document test {\n" + + " }\n" + + rankProfiles + + "\n" + + "}"; + builder.importString(sdContent); + builder.build(); + search = builder.getSearch(); + } + + public void assertFirstPhaseExpression(String expExpression, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); + } + + public void assertSecondPhaseExpression(String expExpression, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); + } + + public void assertRankProperty(String expValue, String name, String rankProfile) { + List rankPropertyList = rankProfile(rankProfile).getRankPropertyMap().get(name); + assertEquals(1, rankPropertyList.size()); + assertEquals(expValue, rankPropertyList.get(0).getValue()); + } + + public void assertMacro(String expExpression, String macroName, String rankProfile) { + assertEquals(expExpression, rankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); + } + + public RankProfile rankProfile(String rankProfile) { + return rankProfileRegistry.getRankProfile(search, rankProfile).compile(); + } +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java new file mode 100644 index 00000000000..5ad85ac872c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -0,0 +1,119 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionWithTensorFlowTestCase { + + private final String modelDirectory = "src/test/integration/tensorflow/mnist_softmax/saved"; + private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; + + @Test + public void testMinimalTensorFlowReference() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + + Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); + assertNotNull("Variable_1 is imported", variable_1); + assertEquals(10, variable_1.size()); + + Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); + assertNotNull("Variable is imported", variable); + assertEquals(7840, variable.size()); + } + + @Test + public void testNestedTensorFlowReference() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" + + " }\n" + + " }"); + search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); + + Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); + assertNotNull("Variable_1 is imported", variable_1); + assertEquals(10, variable_1.size()); + + Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); + assertNotNull("Variable is imported", variable); + assertEquals(7840, variable.size()); + } + + @Test + public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + + modelDirectory + "','serving_defaultz'): Model does not have the specified signatures 'serving_defaultz'", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + + modelDirectory + "','serving_default','x'): Model does not have the specified outputs 'x'", + Exceptions.toMessageString(expected)); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 4dcf7523fd0..dba2bdbfbbf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -1,61 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import java.util.List; - -import static org.junit.Assert.assertEquals; - /** * @author geirst */ public class RankingExpressionWithTensorTestCase { - private static class SearchFixture { - RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - Search search; - SearchFixture(String rankProfiles) throws ParseException { - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); - String sdContent = "search test {\n" + - " document test {\n" + - " }\n" + - rankProfiles + - "\n" + - "}"; - builder.importString(sdContent); - builder.build(); - search = builder.getSearch(); - } - public void assertFirstPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); - } - public void assertSecondPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); - } - public void assertRankProperty(String expValue, String name, String rankProfile) { - List rankPropertyList = getRankProfile(rankProfile).getRankPropertyMap().get(name); - assertEquals(1, rankPropertyList.size()); - assertEquals(expValue, rankPropertyList.get(0).getValue()); - } - public void assertMacro(String expExpression, String macroName, String rankProfile) { - assertEquals(expExpression, getRankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); - } - private RankProfile getRankProfile(String rankProfile) { - return rankProfileRegistry.getRankProfile(search, rankProfile).compile(); - } - } - @Test public void requireThatSingleLineConstantTensorAndTypeCanBeParsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: sum(my_tensor)\n" + @@ -74,7 +32,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatMultiLineConstantTensorAndTypeCanBeParsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: sum(my_tensor)\n" + @@ -96,7 +54,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInSecondPhaseExpression() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " second-phase {\n" + " expression: sum(my_tensor)\n" + @@ -114,7 +72,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInInheritedRankProfile() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile parent {\n" + " constants {\n" + " my_tensor {\n" + @@ -134,7 +92,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatConstantTensorsCanBeUsedInMacro() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " macro my_macro() {\n" + " expression: sum(my_tensor)\n" + @@ -156,7 +114,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatCombinationOfConstantTensorsAndConstantValuesCanBeUsed() throws ParseException { - SearchFixture f = new SearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " first-phase {\n" + " expression: my_number_1 + sum(my_tensor) + my_number_2\n" + @@ -181,7 +139,7 @@ public class RankingExpressionWithTensorTestCase { public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); - new SearchFixture( + new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + " my_tensor {\n" + -- cgit v1.2.3