diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-03-07 17:53:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-03-07 17:53:08 +0100 |
commit | f435d9c6fe2bef62172aa1f18948459b402d0328 (patch) | |
tree | ec031359c8b3d414a24ed2ce92ee3cea5a979f22 | |
parent | 25ff6f44faab887decc871e42b744fc5c06c1178 (diff) | |
parent | de5472f3761f666aa5d990f0d49322f7f6425a76 (diff) |
Merge pull request #5241 from vespa-engine/bratseth/tf-constants-in-parent-doc
Bratseth/tf constants in parent doc
18 files changed, 205 insertions, 80 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index e81d22cefe9..2c177633590 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); - if (store.hasStoredModel()) - return transformFromStoredModel(store, context.rankProfile()); - else // not converted yet - access TensorFlow model files + if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); + else + return transformFromStoredModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); @@ -101,16 +104,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + // Find the specified expression Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); - model.macros().forEach((k, v) -> transformMacro(store, profile, k, v)); + model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v)); return expression.getRoot(); } @@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil profile.addConstant(constantName, asValue(constantValue)); } - private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - Path constantPath = store.writeLargeConstant(constantName, constantValue); + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + + Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + if (!profile.getSearch().getRankingConstants().containsKey(constantName)) { + log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } } } - private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) { + private void transformMacro(ModelStore store, RankProfile profile, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); addMacroToProfile(profile, macroName, expression); } @@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return node; } + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) { return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 8e404e72ec7..06912a980a8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -42,7 +42,7 @@ import static org.junit.Assert.*; public class RankingExpressionWithTensorFlowTestCase { private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); - private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))"; + private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { @@ -252,8 +252,51 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { + String rankProfile = + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: tensor(d0[2],d1[784])(0.0)\n" + + " }\n" + + " macro layer_Variable_read() {\n" + + " expression: tensor(d1[10],d2[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved')" + + " }\n" + + " }"; + + + String vespaExpressionWithoutConstant = + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))"; + RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + + assertNull("Constant overridden by macro is not added", + search.search().getRankingConstants().get("layer_Variable_read")); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + assertNull("Constant overridden by macro is not added", + searchFromStored.search().getRankingConstants().get("layer_Variable_read")); + assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.of(10L)); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + @Test public void testTensorFlowReduceBatchDimension() { - final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); @@ -263,9 +306,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')"); @@ -276,9 +319,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -383,6 +426,16 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + private static class StoringApplicationPackage extends MockApplicationPackage { private final File root; diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java index 6ce08c75dd4..697cf910719 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java @@ -330,14 +330,12 @@ public class FastSearcher extends VespaBackEndSearcher { result.hits().addError(ErrorMessage.createBackendCommunicationError("Error filling hits with summary fields, source: " + getName())); return; } - if (skippedHits==0 && packetWrapper != null) { + if (skippedHits == 0 && packetWrapper != null) { cacheControl.updateCacheEntry(cacheKey, query, packetKeys, receivedPackets); } - if ( skippedHits>0 ) { - getLogger().info("Could not fill summary '" + summaryClass + "' for " + skippedHits + " hits for query: " + result.getQuery()); + if ( skippedHits > 0 ) result.hits().addError(com.yahoo.search.result.ErrorMessage.createEmptyDocsums("Missing hit data for summary '" + summaryClass + "' for " + skippedHits + " hits")); - } result.analyzeHits(); if (query.getTraceLevel() >= 3) { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 7a390ac279b..c61b043a36b 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -245,7 +245,7 @@ public class Dispatcher extends AbstractComponent { } Inspector summaries = new SlimeAdapter(root.field("docsums")); - if ( ! summaries.valid() && !hasErrors) + if ( ! summaries.valid() && ! hasErrors) throw new IllegalArgumentException("Expected a Slime root object containing a 'docsums' field"); for (int i = 0; i < hits.size(); i++) { fill(hits.get(i), summaries.entry(i).field("docsum")); diff --git a/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java b/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java index f690ad10050..e14c46056b7 100644 --- a/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java +++ b/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java @@ -18,13 +18,15 @@ import com.yahoo.processing.request.Properties; * @author tonytv */ public class SourceRefResolver { + private final SearchChainResolver searchChainResolver; public SourceRefResolver(SearchChainResolver searchChainResolver) { this.searchChainResolver = searchChainResolver; } - public Set<SearchChainInvocationSpec> resolve(ComponentSpecification sourceRef, Properties sourceToProviderMap, - IndexFacts indexFacts) + public Set<SearchChainInvocationSpec> resolve(ComponentSpecification sourceRef, + Properties sourceToProviderMap, + IndexFacts indexFacts) throws UnresolvedSearchChainException { try { @@ -47,7 +49,7 @@ public class SourceRefResolver { clusterSearchChains.add(resolveClusterSearchChain(cluster, sourceRef, sourceToProviderMap)); } - if (!clusterSearchChains.isEmpty()) + if ( ! clusterSearchChains.isEmpty()) return clusterSearchChains; } @@ -55,17 +57,22 @@ public class SourceRefResolver { } - private SearchChainInvocationSpec resolveClusterSearchChain(String cluster, ComponentSpecification sourceRef, - Properties sourceToProviderMap) throws UnresolvedSearchChainException { + private SearchChainInvocationSpec resolveClusterSearchChain(String cluster, + ComponentSpecification sourceRef, + Properties sourceToProviderMap) + throws UnresolvedSearchChainException { try { return searchChainResolver.resolve(new ComponentSpecification(cluster), sourceToProviderMap); - } catch (UnresolvedSearchChainException e) { + } + catch (UnresolvedSearchChainException e) { throw new UnresolvedSearchChainException("Failed to resolve cluster search chain " + quote(cluster) + - " when using source ref " + quote(sourceRef) + " as a document name."); + " when using source ref " + quote(sourceRef) + + " as a document name."); } } private boolean hasOnlyName(ComponentSpecification sourceSpec) { return new ComponentSpecification(sourceSpec.getName()).equals(sourceSpec); } + } diff --git a/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java b/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java index 79b5d5bc67e..cf7276d767e 100644 --- a/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java +++ b/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java @@ -8,9 +8,10 @@ import com.yahoo.processing.request.Properties; /** * TODO: What's this? * -* @author tonytv -*/ + * @author tonytv + */ public abstract class Target extends AbstractComponent { + final ComponentId localId; final boolean isDerived; @@ -28,4 +29,5 @@ public abstract class Target extends AbstractComponent { public abstract String searchRefDescription(); abstract void freeze(); + } diff --git a/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java b/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java index e1e61e13119..248e37345a8 100644 --- a/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java +++ b/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java @@ -10,7 +10,7 @@ import java.util.List; * This is a container for all {@link Annotation}s constants used by built-in Vespa features. These must be in sync with * the corresponding class in the C++ file 'document/datatype/annotationtype.h'. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ @SuppressWarnings({ "UnusedDeclaration" }) public final class AnnotationTypes { diff --git a/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java b/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java index ff9c8c07a47..64d7d7cff68 100644 --- a/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java +++ b/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java @@ -4,7 +4,7 @@ package com.yahoo.document.annotation; /** * This is a container for all {@link SpanTree}s constants used by built-in Vespa features. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ @SuppressWarnings({ "UnusedDeclaration" }) // TODO: Remove. This is the wrong place. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java index 0fe73fad8ce..ee358f45b22 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -34,9 +34,7 @@ public class OperationMapper { public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) { switch (node.getOp().toLowerCase()) { - /* - * array ops - */ + // array ops case "const": return new Const(node, inputs, port); case "expanddims": return new ExpandDims(node, inputs, port); case "identity": return new Identity(node, inputs, port); @@ -46,15 +44,11 @@ public class OperationMapper { case "shape": return new Shape(node, inputs, port); case "squeeze": return new Squeeze(node, inputs, port); - /* - * control flow - */ + // control flow case "merge": return new Merge(node, inputs, port); case "switch": return new Switch(node, inputs, port); - /* - * math ops - */ + // math ops case "add": return new Join(node, inputs, port, ScalarFunctions.add()); case "add_n": return new Join(node, inputs, port, ScalarFunctions.add()); case "acos": return new Map(node, inputs, port, ScalarFunctions.acos()); @@ -75,27 +69,17 @@ public class OperationMapper { case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract()); case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract()); - /* - * nn ops - */ + // nn ops case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add()); case "elu": return new Map(node, inputs, port, ScalarFunctions.elu()); case "relu": return new Map(node, inputs, port, ScalarFunctions.relu()); case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); - /* - * random ops - */ - - /* - * state ops - */ + // state ops case "variable": return new Variable(node, inputs, port); case "variablev2": return new Variable(node, inputs, port); - /* - * evaluation no-ops - */ + // evaluation no-ops case "stopgradient":return new Identity(node, inputs, port); case "noop": return new NoOp(node, inputs, port); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java index 7decef51ab7..d06d7b48def 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -46,7 +47,7 @@ public class Const extends TensorFlowOperation { if (type.type().rank() == 0 && getConstantValue().isPresent()) { expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); } else { - expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")"); + expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); } return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); } @@ -72,7 +73,7 @@ public class Const extends TensorFlowOperation { private Value value() { if (!node.getAttrMap().containsKey("value")) { throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + - "const has missing 'value' attribute"); + "const has missing 'value' attribute"); } AttrValue attrValue = node.getAttrMap().get("value"); if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index 9e8f6df3e2c..5d711aac100 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; @@ -67,7 +68,7 @@ public abstract class TensorFlowOperation { public Optional<TensorFunction> function() { if (function == null) { if (isConstant()) { - ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")"); + ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); } else if (outputs.size() > 1) { macro = lazyGetFunction(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 822d6055815..639c5d22d9e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -105,7 +106,8 @@ public final class ReferenceNode extends CompositeNode { // TODO: Context should accept a Reference instead. if (reference.isIdentifier()) return context.get(reference.name()); - return context.get(getName(), getArguments(), getOutput()); + else + return context.get(getName(), getArguments(), getOutput()); } @Override diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index a796eaa4ac0..1ad8706aa2f 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -777,21 +777,10 @@ ConstantNode constantPrimitive() : ( <SUB> { sign = "-";} ) ? ( <INTEGER> { value = token.image; } | <FLOAT> { value = token.image; } | - value = stringPath() ) + <STRING> { value = token.image; } ) { return new ConstantNode(Value.parse(sign + value),sign + value); } } -// Strings separated by "/" -String stringPath() : -{ - StringBuilder b = new StringBuilder(); -} -{ - <STRING> { b.append(token.image); } - ( LOOKAHEAD(2) <DIV> <STRING> { b.append("/").append(token.image); } ) * - { return b.toString(); } -} - Value primitiveValue() : { String sign = ""; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java new file mode 100644 index 00000000000..f275f95ca8e --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java @@ -0,0 +1,33 @@ +package com.yahoo.searchlib.rankingexpression; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +/** + * @author bratseth + */ +public class ReferenceTestCase { + + @Test + public void testSimple() { + assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), null).isSimple()); + assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), "out").isSimple()); + assertTrue(new Reference("foo", new Arguments(new NameNode("arg")), "out").isSimple()); + assertFalse(new Reference("foo", new Arguments(), null).isSimple()); + } + + @Test + public void testToString() { + assertEquals("foo(arg_1)", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), null).toString()); + assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), "out").toString()); + assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new NameNode("arg_1")), "out").toString()); + assertEquals("foo", new Reference("foo", new Arguments(), null).toString()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index c09b1f2b606..a13ff3147c8 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(tf_macro_X, constant(outputs_kernel_read), f(a,b)(a * b)), sum, d2), constant(outputs_bias_read), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 60dd3865aa1..0deac3f8216 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))", output.getRoot().toString()); // Test execution diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 9f372d8d6f5..daacd014b63 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -68,8 +68,8 @@ public class TestableTensorFlowModel { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); return context; } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java index 75b8e1122c1..135cc95a209 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java @@ -9,7 +9,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class ReferenceNodeTestCase { |