summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-03-07 17:53:08 +0100
committerGitHub <noreply@github.com>2018-03-07 17:53:08 +0100
commitf435d9c6fe2bef62172aa1f18948459b402d0328 (patch)
treeec031359c8b3d414a24ed2ce92ee3cea5a979f22
parent25ff6f44faab887decc871e42b744fc5c06c1178 (diff)
parentde5472f3761f666aa5d990f0d49322f7f6425a76 (diff)
Merge pull request #5241 from vespa-engine/bratseth/tf-constants-in-parent-doc
Bratseth/tf constants in parent doc
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java81
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java69
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java6
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java21
-rw-r--r--container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java6
-rw-r--r--document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java2
-rw-r--r--document/src/main/java/com/yahoo/document/annotation/SpanTrees.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java4
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java33
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java2
18 files changed, 205 insertions, 80 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index e81d22cefe9..2c177633590 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -51,6 +53,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(),
feature.getArguments());
- if (store.hasStoredModel())
- return transformFromStoredModel(store, context.rankProfile());
- else // not converted yet - access TensorFlow model files
+ if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
+ else
+ return transformFromStoredModel(store, context.rankProfile());
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
@@ -101,16 +104,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
// Find the specified expression
Signature signature = chooseSignature(model, store.arguments().signature());
String output = chooseOutput(signature, store.arguments().output());
RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles);
store.writeConverted(expression);
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v));
- model.macros().forEach((k, v) -> transformMacro(store, profile, k, v));
+ model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v));
return expression.getRoot();
}
@@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
profile.addConstant(constantName, asValue(constantValue));
}
- private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ "The required type of this is " + constantValue.type() +
+ ", but the macro returns " + macroType);
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
+ if (!profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
}
}
- private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) {
+ private void transformMacro(ModelStore store, RankProfile profile,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
store.writeMacro(macroName, expression);
addMacroToProfile(profile, macroName, expression);
}
@@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return node;
}
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) {
return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions));
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 8e404e72ec7..06912a980a8 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -42,7 +42,7 @@ import static org.junit.Assert.*;
public class RankingExpressionWithTensorFlowTestCase {
private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/");
- private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))";
+ private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))";
@After
public void removeGeneratedConstantTensorFiles() {
@@ -252,8 +252,51 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
+ public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException {
+ String rankProfile =
+ " rank-profile my_profile {\n" +
+ " macro Placeholder() {\n" +
+ " expression: tensor(d0[2],d1[784])(0.0)\n" +
+ " }\n" +
+ " macro layer_Variable_read() {\n" +
+ " expression: tensor(d1[10],d2[784])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('mnist_softmax/saved')" +
+ " }\n" +
+ " }";
+
+
+ String vespaExpressionWithoutConstant =
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))";
+ RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+
+ assertNull("Constant overridden by macro is not added",
+ search.search().getRankingConstants().get("layer_Variable_read"));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ assertNull("Constant overridden by macro is not added",
+ searchFromStored.search().getRankingConstants().get("layer_Variable_read"));
+ assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.of(10L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ @Test
public void testTensorFlowReduceBatchDimension() {
- final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(expression, "my_profile");
@@ -263,9 +306,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
- final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
+ final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')");
@@ -276,9 +319,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
- final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
+ final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -383,6 +426,16 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
+ private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
private static class StoringApplicationPackage extends MockApplicationPackage {
private final File root;
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
index 6ce08c75dd4..697cf910719 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
@@ -330,14 +330,12 @@ public class FastSearcher extends VespaBackEndSearcher {
result.hits().addError(ErrorMessage.createBackendCommunicationError("Error filling hits with summary fields, source: " + getName()));
return;
}
- if (skippedHits==0 && packetWrapper != null) {
+ if (skippedHits == 0 && packetWrapper != null) {
cacheControl.updateCacheEntry(cacheKey, query, packetKeys, receivedPackets);
}
- if ( skippedHits>0 ) {
- getLogger().info("Could not fill summary '" + summaryClass + "' for " + skippedHits + " hits for query: " + result.getQuery());
+ if ( skippedHits > 0 )
result.hits().addError(com.yahoo.search.result.ErrorMessage.createEmptyDocsums("Missing hit data for summary '" + summaryClass + "' for " + skippedHits + " hits"));
- }
result.analyzeHits();
if (query.getTraceLevel() >= 3) {
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
index 7a390ac279b..c61b043a36b 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
@@ -245,7 +245,7 @@ public class Dispatcher extends AbstractComponent {
}
Inspector summaries = new SlimeAdapter(root.field("docsums"));
- if ( ! summaries.valid() && !hasErrors)
+ if ( ! summaries.valid() && ! hasErrors)
throw new IllegalArgumentException("Expected a Slime root object containing a 'docsums' field");
for (int i = 0; i < hits.size(); i++) {
fill(hits.get(i), summaries.entry(i).field("docsum"));
diff --git a/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java b/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java
index f690ad10050..e14c46056b7 100644
--- a/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java
+++ b/container-search/src/main/java/com/yahoo/search/federation/sourceref/SourceRefResolver.java
@@ -18,13 +18,15 @@ import com.yahoo.processing.request.Properties;
* @author tonytv
*/
public class SourceRefResolver {
+
private final SearchChainResolver searchChainResolver;
public SourceRefResolver(SearchChainResolver searchChainResolver) {
this.searchChainResolver = searchChainResolver;
}
- public Set<SearchChainInvocationSpec> resolve(ComponentSpecification sourceRef, Properties sourceToProviderMap,
- IndexFacts indexFacts)
+ public Set<SearchChainInvocationSpec> resolve(ComponentSpecification sourceRef,
+ Properties sourceToProviderMap,
+ IndexFacts indexFacts)
throws UnresolvedSearchChainException {
try {
@@ -47,7 +49,7 @@ public class SourceRefResolver {
clusterSearchChains.add(resolveClusterSearchChain(cluster, sourceRef, sourceToProviderMap));
}
- if (!clusterSearchChains.isEmpty())
+ if ( ! clusterSearchChains.isEmpty())
return clusterSearchChains;
}
@@ -55,17 +57,22 @@ public class SourceRefResolver {
}
- private SearchChainInvocationSpec resolveClusterSearchChain(String cluster, ComponentSpecification sourceRef,
- Properties sourceToProviderMap) throws UnresolvedSearchChainException {
+ private SearchChainInvocationSpec resolveClusterSearchChain(String cluster,
+ ComponentSpecification sourceRef,
+ Properties sourceToProviderMap)
+ throws UnresolvedSearchChainException {
try {
return searchChainResolver.resolve(new ComponentSpecification(cluster), sourceToProviderMap);
- } catch (UnresolvedSearchChainException e) {
+ }
+ catch (UnresolvedSearchChainException e) {
throw new UnresolvedSearchChainException("Failed to resolve cluster search chain " + quote(cluster) +
- " when using source ref " + quote(sourceRef) + " as a document name.");
+ " when using source ref " + quote(sourceRef) +
+ " as a document name.");
}
}
private boolean hasOnlyName(ComponentSpecification sourceSpec) {
return new ComponentSpecification(sourceSpec.getName()).equals(sourceSpec);
}
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java b/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java
index 79b5d5bc67e..cf7276d767e 100644
--- a/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java
+++ b/container-search/src/main/java/com/yahoo/search/federation/sourceref/Target.java
@@ -8,9 +8,10 @@ import com.yahoo.processing.request.Properties;
/**
* TODO: What's this?
*
-* @author tonytv
-*/
+ * @author tonytv
+ */
public abstract class Target extends AbstractComponent {
+
final ComponentId localId;
final boolean isDerived;
@@ -28,4 +29,5 @@ public abstract class Target extends AbstractComponent {
public abstract String searchRefDescription();
abstract void freeze();
+
}
diff --git a/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java b/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java
index e1e61e13119..248e37345a8 100644
--- a/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java
+++ b/document/src/main/java/com/yahoo/document/annotation/AnnotationTypes.java
@@ -10,7 +10,7 @@ import java.util.List;
* This is a container for all {@link Annotation}s constants used by built-in Vespa features. These must be in sync with
* the corresponding class in the C++ file 'document/datatype/annotationtype.h'.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
@SuppressWarnings({ "UnusedDeclaration" })
public final class AnnotationTypes {
diff --git a/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java b/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java
index ff9c8c07a47..64d7d7cff68 100644
--- a/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java
+++ b/document/src/main/java/com/yahoo/document/annotation/SpanTrees.java
@@ -4,7 +4,7 @@ package com.yahoo.document.annotation;
/**
* This is a container for all {@link SpanTree}s constants used by built-in Vespa features.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
@SuppressWarnings({ "UnusedDeclaration" })
// TODO: Remove. This is the wrong place.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index 0fe73fad8ce..ee358f45b22 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -34,9 +34,7 @@ public class OperationMapper {
public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
switch (node.getOp().toLowerCase()) {
- /*
- * array ops
- */
+ // array ops
case "const": return new Const(node, inputs, port);
case "expanddims": return new ExpandDims(node, inputs, port);
case "identity": return new Identity(node, inputs, port);
@@ -46,15 +44,11 @@ public class OperationMapper {
case "shape": return new Shape(node, inputs, port);
case "squeeze": return new Squeeze(node, inputs, port);
- /*
- * control flow
- */
+ // control flow
case "merge": return new Merge(node, inputs, port);
case "switch": return new Switch(node, inputs, port);
- /*
- * math ops
- */
+ // math ops
case "add": return new Join(node, inputs, port, ScalarFunctions.add());
case "add_n": return new Join(node, inputs, port, ScalarFunctions.add());
case "acos": return new Map(node, inputs, port, ScalarFunctions.acos());
@@ -75,27 +69,17 @@ public class OperationMapper {
case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract());
case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract());
- /*
- * nn ops
- */
+ // nn ops
case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add());
case "elu": return new Map(node, inputs, port, ScalarFunctions.elu());
case "relu": return new Map(node, inputs, port, ScalarFunctions.relu());
case "selu": return new Map(node, inputs, port, ScalarFunctions.selu());
- /*
- * random ops
- */
-
- /*
- * state ops
- */
+ // state ops
case "variable": return new Variable(node, inputs, port);
case "variablev2": return new Variable(node, inputs, port);
- /*
- * evaluation no-ops
- */
+ // evaluation no-ops
case "stopgradient":return new Identity(node, inputs, port);
case "noop": return new NoOp(node, inputs, port);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
index 7decef51ab7..d06d7b48def 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -46,7 +47,7 @@ public class Const extends TensorFlowOperation {
if (type.type().rank() == 0 && getConstantValue().isPresent()) {
expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
} else {
- expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
}
return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
}
@@ -72,7 +73,7 @@ public class Const extends TensorFlowOperation {
private Value value() {
if (!node.getAttrMap().containsKey("value")) {
throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
- "const has missing 'value' attribute");
+ "const has missing 'value' attribute");
}
AttrValue attrValue = node.getAttrMap().get("value");
if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index 9e8f6df3e2c..5d711aac100 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
@@ -67,7 +68,7 @@ public abstract class TensorFlowOperation {
public Optional<TensorFunction> function() {
if (function == null) {
if (isConstant()) {
- ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
} else if (outputs.size() > 1) {
macro = lazyGetFunction();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 822d6055815..639c5d22d9e 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -105,7 +106,8 @@ public final class ReferenceNode extends CompositeNode {
// TODO: Context should accept a Reference instead.
if (reference.isIdentifier())
return context.get(reference.name());
- return context.get(getName(), getArguments(), getOutput());
+ else
+ return context.get(getName(), getArguments(), getOutput());
}
@Override
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index a796eaa4ac0..1ad8706aa2f 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -777,21 +777,10 @@ ConstantNode constantPrimitive() :
( <SUB> { sign = "-";} ) ?
( <INTEGER> { value = token.image; } |
<FLOAT> { value = token.image; } |
- value = stringPath() )
+ <STRING> { value = token.image; } )
{ return new ConstantNode(Value.parse(sign + value),sign + value); }
}
-// Strings separated by "/"
-String stringPath() :
-{
- StringBuilder b = new StringBuilder();
-}
-{
- <STRING> { b.append(token.image); }
- ( LOOKAHEAD(2) <DIV> <STRING> { b.append("/").append(token.image); } ) *
- { return b.toString(); }
-}
-
Value primitiveValue() :
{
String sign = "";
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java
new file mode 100644
index 00000000000..f275f95ca8e
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java
@@ -0,0 +1,33 @@
+package com.yahoo.searchlib.rankingexpression;
+
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.NameNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+/**
+ * @author bratseth
+ */
+public class ReferenceTestCase {
+
+ @Test
+ public void testSimple() {
+ assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), null).isSimple());
+ assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), "out").isSimple());
+ assertTrue(new Reference("foo", new Arguments(new NameNode("arg")), "out").isSimple());
+ assertFalse(new Reference("foo", new Arguments(), null).isSimple());
+ }
+
+ @Test
+ public void testToString() {
+ assertEquals("foo(arg_1)", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), null).toString());
+ assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), "out").toString());
+ assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new NameNode("arg_1")), "out").toString());
+ assertEquals("foo", new Reference("foo", new Arguments(), null).toString());
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index c09b1f2b606..a13ff3147c8 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(tf_macro_X, constant(outputs_kernel_read), f(a,b)(a * b)), sum, d2), constant(outputs_bias_read), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 60dd3865aa1..0deac3f8216 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 9f372d8d6f5..daacd014b63 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -68,8 +68,8 @@ public class TestableTensorFlowModel {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
return context;
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java
index 75b8e1122c1..135cc95a209 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java
@@ -9,7 +9,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public class ReferenceNodeTestCase {