diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-11 16:09:29 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-11 16:09:29 +0100 |
commit | 125bdaf97ebc63c5c6892ce4311418f0568908ce (patch) | |
tree | bac16eb938092892dec00483fc2503ba881da8c9 /config-model | |
parent | ad378e79897bfa229ddd84365f211889b1703671 (diff) |
Complete support for TF constants as files (deactivated)
Diffstat (limited to 'config-model')
10 files changed, 85 insertions, 39 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index bd7b8ce6e15..f37ab9fb89f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -171,8 +171,9 @@ public class Search implements Serializable, ImmutableSearch { rankingConstants.put(name, constant); } - public Iterable<RankingConstant> getRankingConstants() { - return rankingConstants.values(); + /** Returns a read-only map of the ranking constants in this indexed by name */ + public Map<String, RankingConstant> getRankingConstants() { + return Collections.unmodifiableMap(rankingConstants); } public Optional<TemporaryImportedFields> temporaryImportedFields() { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 5f9a4d98b43..67d60b08ab0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.collect.ImmutableList; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index a36384ce6f2..32f8f4871df 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,7 +1,9 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; @@ -15,8 +17,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -28,8 +33,13 @@ import java.util.Optional; * * @author bratseth */ +// TODO: - Verify types of macros +// - Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { + // TODO: Make system test work with this set to true, then remove the "true" path + private static final boolean constantsInConfig = true; + private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ @@ -63,7 +73,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil optionalArgument(2, feature.getArguments())); // Add all constants (after finding outputs to fail faster when the output is not found) - if (1==1) + if (constantsInConfig) result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); else // correct way, disabled for now result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v)); @@ -129,15 +139,23 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { - File constantFilePath = new File(modelPath, "converted_variables"); - if ( ! constantFilePath.exists() ) { - if ( ! constantFilePath.mkdir() ) - throw new IllegalStateException("Could not create directory " + constantFilePath); - } + try { + if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; + + File constantFilePath = new File(modelPath, "converted_variables").getCanonicalFile(); + if (!constantFilePath.exists()) { + if (!constantFilePath.mkdir()) + throw new IOException("Could not create directory " + constantFilePath); + } - File constantFile = new File(constantFilePath, constantName + ".json"); - // writeAsVespaTensor(constantValue, constantFile); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath())); + // "tbf" ending for "typed binary format" - recognized by the nodes reciving the file: + File constantFile = new File(constantFilePath, constantName + ".tbf"); + IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue)); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath())); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } } private String skippedOutputsDescription(TensorFlowModel.Signature signature) { @@ -170,5 +188,4 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return c == '\'' || c == '"'; } - } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 43c0817c986..5255cdaeba1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -205,11 +205,9 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor } private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) { - for (RankingConstant rankingConstant : profile.getSearch().getRankingConstants()) { - if (rankingConstant.getName().equals(name)) { - context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); - } - } + RankingConstant constant = profile.getSearch().getRankingConstants().get(name); + if (constant != null) + context.put(node.toString(), emptyTensorValue(constant.getTensorType())); } private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index a1f372f2307..6ae9883c082 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -48,7 +48,7 @@ public class RankingConstantsValidator extends Validator { ExceptionMessageCollector exceptionMessageCollector = new ExceptionMessageCollector("Invalid constant tensor file(s):"); for (SearchDefinition sd : deployState.getSearchDefinitions()) { - for (RankingConstant rc : sd.getSearch().getRankingConstants()) { + for (RankingConstant rc : sd.getSearch().getRankingConstants().values()) { try { validateRankingConstant(rc, applicationPackage); } catch (InvalidConstantTensor | FileNotFoundException ex) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index e3eb66e6a18..fd062dc4ea4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -37,7 +37,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer public void prepareToDistributeFiles(List<SearchNode> backends) { for (SearchDefinitionSpec sds : localSDS) { - for (RankingConstant constant : sds.getSearchDefinition().getSearch().getRankingConstants()) { + for (RankingConstant constant : sds.getSearchDefinition().getSearch().getRankingConstants().values()) { FileReference reference = FileSender.sendFileToServices(constant.getFileName(), backends); constant.setFileReference(reference.value()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index 32548039fdd..1413d515103 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,7 +75,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : derivedCfg.getSearch().getRankingConstants()) { + for (RankingConstant constant : derivedCfg.getSearch().getRankingConstants().values()) { if ("".equals(constant.getFileReference())) { System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning continue; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java index 38aa9a5d53a..2880af9e74f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java @@ -43,7 +43,7 @@ public class RankingConstantTest { searchBuilder.build(); Search search = searchBuilder.getSearch(); - Iterator<RankingConstant> constantIterator = search.getRankingConstants().iterator(); + Iterator<RankingConstant> constantIterator = search.getRankingConstants().values().iterator(); RankingConstant constant = constantIterator.next(); assertEquals(TENSOR_NAME, constant.getName()); assertEquals(TENSOR_FILE, constant.getFileName()); @@ -99,7 +99,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().iterator().next(); + RankingConstant constant = search.getRankingConstants().values().iterator().next(); assertEquals("simplename", constant.getFileName()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index e71a627d7db..ff53fdafacf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -55,4 +55,7 @@ class RankProfileSearchFixture { public RankProfile rankProfile(String rankProfile) { return rankProfileRegistry.getRankProfile(search, rankProfile).compile(); } + + public Search search() { return search; } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 3ec621618e5..31f7511155b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,13 +1,25 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.yolean.Exceptions; +import org.junit.After; import org.junit.Test; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** @@ -19,6 +31,11 @@ public class RankingExpressionWithTensorFlowTestCase { private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved"; private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; + @After + public void removeGeneratedConstantTensorFiles() { + IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables")); + } + @Test public void testMinimalTensorFlowReference() throws ParseException { RankProfileSearchFixture search = new RankProfileSearchFixture( @@ -28,14 +45,8 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - - Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); - assertNotNull("Variable_1 is imported", variable_1); - assertEquals(10, variable_1.size()); - - Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); - assertNotNull("Variable is imported", variable); - assertEquals(7840, variable.size()); + assertConstant(10, "Variable_1", search); + assertConstant(7840, "Variable", search); } @Test @@ -47,14 +58,8 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - - Tensor variable_1 = search.rankProfile("my_profile").getConstants().get("Variable_1").asTensor(); - assertNotNull("Variable_1 is imported", variable_1); - assertEquals(10, variable_1.size()); - - Tensor variable = search.rankProfile("my_profile").getConstants().get("Variable").asTensor(); - assertNotNull("Variable is imported", variable); - assertEquals(7840, variable.size()); + assertConstant(10, "Variable_1", search); + assertConstant(7840, "Variable", search); } @Test @@ -117,4 +122,25 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private void assertConstant(int expectedSize, String name, RankProfileSearchFixture search) { + try { + TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove + if (constant == null) { // New way + File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName()); + assertTrue("Constant file has been written", constantFile.exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantFile))); + assertEquals(expectedSize, deserializedConstant.size()); + } else { // Old way. TODO: Remove + assertNotNull(name + " is imported", constant); + assertEquals(expectedSize, constant.asTensor().size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } |