From 3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 13 Dec 2017 15:21:44 +0100 Subject: Check agreement between TF and Vespa execution --- .../yahoo/searchdefinition/TensorTransformer.java | 2 +- .../yahoo/searchdefinition/document/Attribute.java | 2 +- .../document/ImmutableImportedSDField.java | 18 +- .../configmodel/producers/DocumentManager.java | 2 +- .../validation/ConstantTensorJsonValidator.java | 2 +- .../validation/RankingConstantsValidator.java | 2 +- .../searchdefinition/RankProfileTestCase.java | 2 +- .../derived/ExportingTestCase.java | 2 +- .../processing/TensorTransformTestCase.java | 2 +- .../com/yahoo/prelude/fastsearch/DocsumField.java | 12 +- .../java/com/yahoo/prelude/fastsearch/FastHit.java | 8 +- .../com/yahoo/prelude/fastsearch/TensorField.java | 2 +- .../query/profile/types/TensorFieldType.java | 2 +- .../prelude/fastsearch/SlimeSummaryTestCase.java | 2 +- .../java/com/yahoo/search/test/QueryTestCase.java | 29 +- .../src/main/java/com/yahoo/document/DataType.java | 8 +- .../com/yahoo/document/DocumentTypeManager.java | 12 +- .../java/com/yahoo/document/TensorDataType.java | 6 +- .../yahoo/document/datatypes/TensorFieldValue.java | 4 +- .../yahoo/document/json/JsonReaderTestCase.java | 8 +- .../TensorFieldValueSerializationTestCase.java | 2 +- .../rankingexpression/evaluation/Context.java | 44 +- .../evaluation/DoubleCompatibleValue.java | 8 + .../rankingexpression/evaluation/StringValue.java | 9 +- .../rankingexpression/evaluation/TensorValue.java | 14 +- .../rankingexpression/evaluation/Value.java | 10 + .../integration/tensorflow/ImportResult.java | 51 + .../integration/tensorflow/NamedTensor.java | 23 - .../integration/tensorflow/OperationMapper.java | 115 +- .../integration/tensorflow/TensorConverter.java | 6 +- .../integration/tensorflow/TensorFlowImporter.java | 114 +- .../tensorflow/TypedTensorFunction.java | 4 +- .../rule/GeneratorLambdaFunctionNode.java | 14 +- .../rankingexpression/rule/TensorFunctionNode.java | 9 +- .../mnist_softmax/mnist_sftmax_with_saving.py | 89 + .../mnist_softmax/saved/saved_model.pbtxt | 5039 ++++++++++++++++++++ .../saved/variables/variables.data-00000-of-00001 | Bin 0 -> 31400 bytes .../mnist_softmax/saved/variables/variables.index | Bin 0 -> 159 bytes .../tensorflow/model1/saved_model.pbtxt | 4909 ------------------- .../model1/variables/variables.data-00000-of-00001 | Bin 31400 -> 0 bytes .../tensorflow/model1/variables/variables.index | Bin 159 -> 0 bytes .../evaluation/EvaluationTestCase.java | 13 +- .../evaluation/EvaluationTester.java | 4 +- .../tensorflow/Mnist_SoftmaxTestCase.java | 114 + .../tensorflow/TensorFlowImporterTestCase.java | 79 - .../searchlib/tensor/TensorConformanceTest.java | 4 +- .../main/java/com/yahoo/tensor/DimensionSizes.java | 4 +- .../main/java/com/yahoo/tensor/IndexedTensor.java | 164 +- .../main/java/com/yahoo/tensor/MappedTensor.java | 16 +- .../main/java/com/yahoo/tensor/MixedTensor.java | 6 +- .../src/main/java/com/yahoo/tensor/Tensor.java | 46 +- .../main/java/com/yahoo/tensor/TensorAddress.java | 24 +- .../main/java/com/yahoo/tensor/TensorParser.java | 2 +- .../src/main/java/com/yahoo/tensor/TensorType.java | 45 +- .../yahoo/tensor/evaluation/EvaluationContext.java | 9 +- .../tensor/evaluation/MapEvaluationContext.java | 4 +- .../yahoo/tensor/evaluation/VariableTensor.java | 8 +- .../tensor/functions/CompositeTensorFunction.java | 2 +- .../java/com/yahoo/tensor/functions/Concat.java | 10 +- .../com/yahoo/tensor/functions/ConstantTensor.java | 6 
+- .../main/java/com/yahoo/tensor/functions/Diag.java | 8 +- .../java/com/yahoo/tensor/functions/Generate.java | 12 +- .../main/java/com/yahoo/tensor/functions/Join.java | 32 +- .../main/java/com/yahoo/tensor/functions/Map.java | 2 +- .../java/com/yahoo/tensor/functions/Matmul.java | 8 +- .../tensor/functions/PrimitiveTensorFunction.java | 4 +- .../java/com/yahoo/tensor/functions/Random.java | 6 +- .../java/com/yahoo/tensor/functions/Range.java | 8 +- .../java/com/yahoo/tensor/functions/Reduce.java | 34 +- .../java/com/yahoo/tensor/functions/Rename.java | 22 +- .../com/yahoo/tensor/functions/TensorFunction.java | 6 +- .../yahoo/tensor/serialization/BinaryFormat.java | 2 +- .../tensor/serialization/DenseBinaryFormat.java | 6 +- .../tensor/serialization/TypedBinaryFormat.java | 6 +- .../com/yahoo/tensor/TensorFunctionBenchmark.java | 14 +- .../test/java/com/yahoo/tensor/TensorTestCase.java | 14 +- .../com/yahoo/tensor/functions/JoinTestCase.java | 10 +- .../tensor/functions/TensorFunctionTestCase.java | 6 +- .../serialization/DenseBinaryFormatTestCase.java | 4 +- .../serialization/MixedBinaryFormatTestCase.java | 2 +- .../serialization/SerializationTestCase.java | 4 +- .../serialization/SparseBinaryFormatTestCase.java | 2 +- 82 files changed, 5842 insertions(+), 5517 deletions(-) create mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java create mode 100644 searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py create mode 100644 searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt create mode 100644 searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 create mode 100644 searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index delete mode 100644 searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt delete mode 100644 searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 delete mode 100644 searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java delete mode 100644 searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java index 69e353ceb35..65176006a2a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java @@ -118,7 +118,7 @@ public class TensorTransformer extends ExpressionTransformer { private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java 
b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java index c52a5dc465d..f932265cb93 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java @@ -259,7 +259,7 @@ public final class Attribute implements Cloneable, Serializable { throw new IllegalArgumentException("Field " + fieldType + " not supported in convertCollectionType"); } } - + private static Optional convertTensorType(DataType fieldType) { if ( ! ( fieldType instanceof TensorDataType)) return Optional.empty(); return Optional.of(((TensorDataType)fieldType).getTensorType()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java index a70e77a17a2..d62d1f5200a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java @@ -29,7 +29,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public boolean containsExpression(Class searchFor) { - throw createUnsupportedException(); + throw createUnsupportedException(searchFor.getSimpleName()); } @Override @@ -79,7 +79,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Index getIndex(String name) { - throw createUnsupportedException(); + throw createUnsupportedException("index"); } @Override @@ -99,7 +99,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ScriptExpression getIndexingScript() { - throw createUnsupportedException(); + throw createUnsupportedException("indexing"); } @Override @@ -114,12 +114,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ImmutableSDField getStructField(String name) { - throw createUnsupportedException(); + throw createUnsupportedException("struct"); } @Override public Collection getStructFields() { - throw createUnsupportedException(); + throw createUnsupportedException("struct"); } @Override @@ -129,12 +129,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Stemming getStemming(Search search) { - throw createUnsupportedException(); + throw createUnsupportedException("stemming"); } @Override public Ranking getRanking() { - throw createUnsupportedException(); + throw createUnsupportedException("ranking"); } @Override @@ -153,8 +153,8 @@ public class ImmutableImportedSDField implements ImmutableSDField { importedField.targetField().getDataType()); } - private static UnsupportedOperationException createUnsupportedException() { - return new UnsupportedOperationException("This aspect is not meaningful or relevant for an imported field."); + private static UnsupportedOperationException createUnsupportedException(String aspect) { + return new UnsupportedOperationException("'" + aspect + "' is not meaningful or relevant for an imported field."); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java index 96a9448739a..9368d6aaa39 100644 --- a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java +++ b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java @@ 
-20,7 +20,7 @@ import java.util.Set; */ public class DocumentManager { - public DocumentmanagerConfig.Builder produce(DocumentModel model, + public DocumentmanagerConfig.Builder produce(DocumentModel model, DocumentmanagerConfig.Builder documentConfigBuilder) { documentConfigBuilder.enablecompression(false); Set handled = new HashSet<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 6eeb12ffdd9..ce3c04f41f7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -45,7 +45,7 @@ public class ConstantTensorJsonValidator { throw new IllegalArgumentException("Ranking constant file names must end with either '.json' or '.json.lz4'"); } } - + private void validateTensor(TensorType type, Reader tensorData) { wrapIOException(() -> { this.parser = jsonFactory.createParser(tensorData); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index 4a9310799aa..c686f023d5b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -64,7 +64,7 @@ public class RankingConstantsValidator extends Validator { private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage applicationPackage) throws FileNotFoundException { ApplicationFile tensorApplicationFile = applicationPackage.getFile(Path.fromString(rankingConstant.getFileName())); - new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), + new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), rankingConstant.getTensorType(), tensorApplicationFile.createReader()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 9407c21fee8..960a3b7d6db 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -173,7 +173,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.numeric").isPresent()); } - + private static Optional findProperty(List> properties, String key) { for (Pair property : properties) if (property.getFirst().equals(key)) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java index 7cd00e155bb..4600f6ae4c6 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java @@ -123,7 +123,7 @@ public class ExportingTestCase extends AbstractExportingTestCase { public void testIndexinfoFieldsets() throws IOException, 
ParseException { assertCorrectDeriving("indexinfo_fieldsets"); } - + @Test public void testStreamingJuniper() throws IOException, ParseException { assertCorrectDeriving("streamingjuniper"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 12bdd8d2b5c..e5693d24f0f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -202,5 +202,5 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { } return b.toString(); } - + } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java index fc1bbace092..1e44a8fa64d 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java @@ -1,16 +1,16 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.fastsearch; +import com.yahoo.container.search.LegacyEmulationConfig; +import com.yahoo.data.access.Inspector; +import com.yahoo.log.LogLevel; + import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; -import com.yahoo.data.access.Inspector; -import com.yahoo.container.search.LegacyEmulationConfig; - -import com.yahoo.log.LogLevel; /** * @author Bjørn Borud @@ -25,7 +25,7 @@ public abstract class DocsumField { Map> constructors = new HashMap<>(); - void put(String typename, Class fieldClass) + void put(String typename, Class fieldClass) throws NoSuchMethodException, SecurityException { Constructor constructor = fieldClass.getConstructor(String.class); constructors.put(typename, constructor); @@ -106,7 +106,7 @@ public abstract class DocsumField { public abstract Object decode(ByteBuffer b); /** - * Get the number of bytes this field occupies in the given buffer + * Get the number of bytes this field occupies in the given buffer * AND SET(!) the position to the first byte after this field. */ public abstract int getLength(ByteBuffer b); diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java index 1524a4da426..692e93bed7e 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java @@ -109,7 +109,7 @@ public class FastHit extends Hit { /** * Returns the explicitly set uri if available, returns "index:[source]/[partid]/[id]" otherwise - * + * * @return uri of hit */ public URI getUri() { @@ -128,9 +128,9 @@ public class FastHit extends Hit { } /** - * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). + * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). * This is the uri if no other uri is assigned - * + * * @return uri to the index. */ public URI getIndexUri() { @@ -215,7 +215,7 @@ public class FastHit extends Hit { * The empty string ("") if no value is assigned in the document. * *
  • Dynamic summary string fields: A Java String before JuniperSearcher and a HitField after.
  • - * + * *
  • Numerics: The corresponding numeric Java type.
    * If the field has no value assigned in the document, * the special numeric {@link com.yahoo.search.result.NanNumber#NaN} is returned. diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java index e0ca7fbe6e1..d8b38667224 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java @@ -13,7 +13,7 @@ import java.util.Optional; /** * A tensor field. Tensors are encoded as a data field where the data (following the length) * is encoded in a tensor binary format defined by com.yahoo.tensor.serialization.TypedBinaryFormat - * + * * @author bratseth */ public class TensorField extends DocsumField implements VariableLengthField { diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 0ec15b95b0d..0fd529bf262 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -16,7 +16,7 @@ import java.util.Optional; public class TensorFieldType extends FieldType { // TODO: Require tensor type - + private final Optional type; /** Creates a tensor field type with optional information about the kind of tensor this will hold */ diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java index 15a0fd60511..5494d1965f8 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java @@ -102,7 +102,7 @@ public class SlimeSummaryTestCase { public void testDecoding() { Tensor tensor1 = Tensor.from("tensor(x{},y{}):{{x:foo,y:bar}:0.1}"); Tensor tensor2 = Tensor.from("tensor(x[],y[1]):{{x:0,y:0}:-0.3}"); - + String summary_cf = "file:src/test/java/com/yahoo/prelude/fastsearch/summary.cfg"; DocsumDefinitionSet set = createDocsumDefinitionSet(summary_cf); byte[] docsum = makeDocsum(tensor1, tensor2); diff --git a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java index e59c03b33c3..62eacaa0afe 100644 --- a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java @@ -2,8 +2,6 @@ package com.yahoo.search.test; import com.yahoo.component.chain.Chain; -import com.yahoo.language.Language; -import com.yahoo.language.Linguistics; import com.yahoo.language.detect.Detection; import com.yahoo.language.detect.Detector; import com.yahoo.language.detect.Hint; @@ -28,7 +26,6 @@ import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.result.Hit; import com.yahoo.search.searchchain.Execution; - import com.yahoo.yolean.Exceptions; import org.junit.Ignore; import org.junit.Test; @@ -45,14 +42,14 @@ import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; -import static org.junit.Assert.assertNotEquals; -import static 
org.junit.Assert.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** @@ -69,7 +66,7 @@ public class QueryTestCase { assertEquals("", q.properties().get("aParameter")); assertNull(q.properties().get("notSetParameter")); } - + // TODO: YQL work in progress (jon) @Ignore @Test @@ -693,7 +690,7 @@ public class QueryTestCase { List l = QueryTree.getPositiveTerms(i); assertEquals(3, l.size()); } - + @Test public void testHeuristicLanguageDetectionTextExtraction() { assertDetectionText("b ", "a:b", "text:a", "text:default"); @@ -720,27 +717,27 @@ public class QueryTestCase { q.getModel().getQueryTree(); // cause parsing assertEquals(expectedDetectionText, mockLinguistics.detector.lastDetectionText); } - + /** A linguistics instance which records the last language detection text passed to it */ private static class MockLinguistics extends SimpleLinguistics { final MockDetector detector = new MockDetector(); - + @Override public Detector getDetector() { return detector; } - + } - + private static class MockDetector extends SimpleDetector { String lastDetectionText = null; - + @Override public Detection detect(String input, Hint hint) { lastDetectionText = input; return super.detect(input, hint); } - + } protected boolean contains(String lineSubstring,String[] lines) { diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index c8a04866aa9..abdbf394591 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -51,7 +51,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory()); public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory()); public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory()); - public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately + public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor // Tags are converted to weightedset when reading the search definition TODO: Remove it @@ -99,7 +99,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com /** * Creates a field value by reflection - * + * * @param arg the value of the newly created field value * @return a fully constructed value */ @@ -201,7 +201,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public static TensorDataType getTensor(TensorType type) { return new TensorDataType(type); } - + public String getName() { 
return name; } @@ -267,7 +267,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com */ public FieldPath buildFieldPath(String fieldPathString) { if (fieldPathString.length() > 0) { - throw new IllegalArgumentException("Datatype " + toString() + + throw new IllegalArgumentException("Datatype " + toString() + " does not support further recursive structure: " + fieldPathString); } return new FieldPath(); diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java index 8c9318199d8..5fad35a2287 100644 --- a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java +++ b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java @@ -38,7 +38,7 @@ public class DocumentTypeManager { // *Configured data types* (not built-in/primitive) indexed by their id // // *Primitive* data types are always available and have a single id. - // + // // *Built-in dynamic* types: The tensor type. // Any tensor type has the same id and is always available just like primitive types. // However, unlike primitive types, each tensor type is a separate DataType instance @@ -112,7 +112,7 @@ public class DocumentTypeManager { public DataType getDataType(String name) { if (name.startsWith("tensor(")) // built-in dynamic return new TensorDataType(TensorType.fromSpec(name)); - + List foundTypes = new ArrayList<>(); for (DataType type : dataTypes.values()) { if (type.getName().equalsIgnoreCase(name)) { @@ -141,10 +141,10 @@ public class DocumentTypeManager { } public DataType getDataType(int code) { return getDataType(code, ""); } - + /** * Return a data type instance - * + * * @param code the code of the data type to return, which must be either built in or present in this manager * @param detailedType detailed type information, or the empty string if none * @return the appropriate DataType instance @@ -183,7 +183,7 @@ public class DocumentTypeManager { /** * Register a single datatype. Re-registering an existing, but equal, datatype is ok. - * + * * @param type The datatype to register */ void registerSingleType(DataType type) { @@ -280,7 +280,7 @@ public class DocumentTypeManager { /** * Returns a read only view of the registered data types - * + * * @return collection of types */ public Collection getDataTypes() { diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java index aefdc030a12..50e9cf0f60f 100644 --- a/document/src/main/java/com/yahoo/document/TensorDataType.java +++ b/document/src/main/java/com/yahoo/document/TensorDataType.java @@ -8,13 +8,13 @@ import com.yahoo.vespa.objects.Ids; /** * A DataType containing a tensor type - * + * * @author bratseth */ public class TensorDataType extends DataType { private final TensorType tensorType; - + // The global class identifier shared with C++. 
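The comments above describe tensor types as "built-in dynamic": every tensor type shares the single data type code 21, and the concrete type is reconstructed on demand from its spec string instead of being registered up front. A small sketch of what that means for callers, using the API shown in this patch plus a default-constructed DocumentTypeManager (the class name below is made up):

import com.yahoo.document.DataType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.TensorDataType;
import com.yahoo.tensor.TensorType;

public class TensorDataTypeSketch {

    public static void main(String[] args) {
        // Parse a type spec into a TensorType and wrap it as a document data type
        TensorType type = TensorType.fromSpec("tensor(x{},y[10])");
        TensorDataType dataType = DataType.getTensor(type);
        System.out.println(dataType.getTensorType()); // tensor(x{},y[10])

        // DocumentTypeManager resolves any "tensor(...)" name dynamically,
        // without that particular tensor type ever having been registered
        DocumentTypeManager manager = new DocumentTypeManager();
        DataType resolved = manager.getDataType("tensor(x{},y[10])");
        System.out.println(resolved instanceof TensorDataType); // true
    }

}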
public static int classId = registerClass(Ids.document + 59, TensorDataType.class); @@ -47,5 +47,5 @@ public class TensorDataType extends DataType { /** Returns the type of the tensor this field can hold */ public TensorType getTensorType() { return tensorType; } - + } diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index ae8d5cf596a..1808396986e 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -19,7 +19,7 @@ import java.util.Optional; public class TensorFieldValue extends FieldValue { private Optional tensor; - + private final TensorDataType dataType; /** Create an empty tensor field value */ @@ -66,7 +66,7 @@ public class TensorFieldValue extends FieldValue { o.getClass().getName() + "'."); } } - + public void assignTensor(Optional tensor) { if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType())) throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index f37fa5ea675..29ba244a9f1 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -146,9 +146,9 @@ public class JsonReaderTestCase { } { DocumentType x = new DocumentType("testtensor"); - x.addField(new Field("mappedtensorfield", + x.addField(new Field("mappedtensorfield", new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); - x.addField(new Field("indexedtensorfield", + x.addField(new Field("indexedtensorfield", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); types.registerDocumentType(x); } @@ -1280,8 +1280,8 @@ public class JsonReaderTestCase { return (DocumentPut) reader.next(); } - private DocumentPut createPutWithMappedTensor(String inputTensor) { - return createPutWithTensor(inputTensor, "mappedtensorfield"); + private DocumentPut createPutWithMappedTensor(String inputTensor) { + return createPutWithTensor(inputTensor, "mappedtensorfield"); } private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) { InputStream rawDoc = new ByteArrayInputStream( diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java index 7104c1686f8..5c65b11a0c4 100644 --- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java +++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java @@ -24,7 +24,7 @@ public class TensorFieldValueSerializationTestCase { private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build(); private final static String TENSOR_FIELD = "my_tensor"; private final static String TENSOR_FILES = "src/test/resources/tensor/"; - private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), + private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), "id:test:my_type::foo"); private static DocumentType createDocType() { diff 
--git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 785ed78492e..0eeb0a9e630 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext { /** *
    Returns the value of a simple variable name.
    * - * @param name The name of the variable whose value to return. - * @return The value of the named variable. + * @param name the name of the variable whose value to return. + * @return the value of the named variable. */ public abstract Value get(String name); + /** Returns a variable as a tensor */ + @Override + public Tensor getTensor(String name) { return get(name).asTensor(); } + /** *
    Returns the value of a structured variable on the form * name(argument*)(.output)?, where argument is any * string. This may be used to implement more advanced variables whose * values are calculated at runtime from arguments. Supporting this in a - * context is optional. - * + * context is optional. + * *
    This default implementation generates a name on the form * name(argument1, argument2, ...argumentN).output. * If there are no arguments the parenthesis are omitted. * If there is no output, the dot is omitted.
    * - * @param name The name of this variable. - * @param arguments The parsed arguments as given in the textual expression. - * @param output The name of the value to output (to enable one named + * @param name the name of this variable. + * @param arguments the parsed arguments as given in the textual expression. + * @param output the name of the value to output (to enable one named * calculation to output several), or null to output the * "main" (or only) value. */ @@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext { * context subclasses. This default implementation throws * UnsupportedOperationException.
    * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public Value get(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); } /** - *
    Lookup by index rather than name directly to a double. This is supported by some optimized + * Lookup by index rather than name directly to a double. This is supported by some optimized * context subclasses. This default implementation throws - * UnsupportedOperationException.
    + * UnsupportedOperationException. * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public double getDouble(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); @@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext { } /** - *
    Sets a value to this, or throws an UnsupportedOperationException if - * this is not supported. This default implementation does the latter.
    * + * Sets a value to this, or throws an UnsupportedOperationException if + * this is not supported. This default implementation does the latter. * - * @param name The name of the variable to set. + * @param name the name of the variable to set. * @param value the value to set. Ownership of this value is transferred to this - if it is mutable * (not frozen) it may be modified during execution - * @since 5.1.5 */ public void put(String name, Value value) { throw new UnsupportedOperationException(this + " does not support variable assignment"); } /** - *
    Returns all the names available in this, or throws an + * Returns all the names available in this, or throws an * UnsupportedOperationException if this operation is not supported. This - * default implementation does the latter.
    + * default implementation does the latter. * - * @return The set of all variable names. + * @return the set of all variable names. */ public Set names() { throw new UnsupportedOperationException(this + " does not support return a list of its names"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index ea750295423..2ef4a2ede2f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -15,6 +18,11 @@ public abstract class DoubleCompatibleValue extends Value { @Override public boolean hasDouble() { return true; } + @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + @Override public Value negate() { return new DoubleValue(-asDouble()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index ac8aba6a617..dad69b31181 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A string value. * * @author bratseth - * @since 5.1.21 */ public class StringValue extends Value { @@ -34,6 +36,11 @@ public class StringValue extends Value { return UnicodeUtilities.unquote(value).hashCode(); } + @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + @Override public boolean hasDouble() { return true; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 49c3ccb7b01..26c30fe5ed2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -2,14 +2,10 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.TensorType; - -import java.util.Collections; -import java.util.Optional; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; /** * A Value containing a tensor. 
@@ -23,7 +19,7 @@ public class TensorValue extends Value { /** The tensor value of this */ private final Tensor value; - + public TensorValue(Tensor value) { this.value = value; } @@ -131,7 +127,7 @@ public class TensorValue extends Value { public Value compare(TruthOperator operator, Value argument) { return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); } - + private Tensor compareTensor(TruthOperator operator, Tensor argument) { switch (operator) { case LARGER: return value.larger(argument); @@ -152,7 +148,7 @@ public class TensorValue extends Value { else return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); } - + private Tensor functionOnTensor(Function function, Tensor argument) { switch (function) { case min: return value.min(argument); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index b2ccbe572d0..40d70e0022c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * The result of a ranking expression evaluation. @@ -25,6 +27,14 @@ public abstract class Value { return new DoubleValue(asDouble()); } + /** Returns this as a tensor value */ + public abstract Tensor asTensor(); + + /** A utility method for wrapping a sdouble in a rank 0 tensor */ + protected Tensor doubleAsTensor(double value) { + return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build(); + } + /** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */ public abstract boolean hasDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java new file mode 100644 index 00000000000..b4a9b363ade --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -0,0 +1,51 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * The result of importing a TensorFlow model into Vespa: + * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model + * - A list of named constant tensors + * - A list of expected input tensors, with their tensor type + * - A list of warning messages + * + * @author bratseth + */ +// This object can be built incrementally within this package, but is immutable when observed from outside the package +// TODO: Retain signature structure in ImportResult (input + output-expression bundles) +public class ImportResult { + + private final List expressions = new ArrayList<>(); + private final Map constants = 
new HashMap<>(); + private final Map arguments = new HashMap<>(); + private final List warnings = new ArrayList<>(); + + void add(RankingExpression expression) { expressions.add(expression); } + void set(String name, Tensor constant) { constants.put(name, constant); } + void set(String name, TensorType argument) { arguments.put(name, argument); } + void warn(String warning) { warnings.add(warning); } + + /** Returns an immutable list of the expressions of this */ + public List expressions() { return Collections.unmodifiableList(expressions); } + + /** Returns an immutable map of the constants of this */ + public Map constants() { return Collections.unmodifiableMap(constants); } + + /** Returns an immutable map of the arguments of this */ + public Map arguments() { return Collections.unmodifiableMap(arguments); } + + /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ + public List warnings() { + return warnings.stream().sorted().collect(Collectors.toList()); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java deleted file mode 100644 index 235771bfa9c..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java +++ /dev/null @@ -1,23 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.Tensor; - -/** - * A tensor with a name - * - * @author bratseth - */ -public class NamedTensor { - - private final String name; - private final Tensor tensor; - - public NamedTensor(String name, Tensor tensor) { - this.name = name; - this.tensor = tensor; - } - - public String name() { return name; } - public Tensor tensor() { return tensor; } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 183cfabbd87..e7f7b5ef2f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -10,13 +10,14 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; +import com.yahoo.tensor.functions.TensorFunction; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; +import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; @@ -45,16 +46,35 @@ class OperationMapper { private TensorConverter tensorConverter = new TensorConverter(); TypedTensorFunction join(List arguments, DoubleBinaryOperator doubleFunction) { - // Note that this generalizes the corresponding TF function as it does not verify that the tensor - // types are the same, with the assumption that this already happened on the TF side - // (and if not, this should do the right thing anyway) ensureArguments(2, arguments, "join"); TypedTensorFunction a = arguments.get(0); TypedTensorFunction b = arguments.get(1); + if (a.type().rank() < b.type().rank()) + throw 
new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " + + "but this is not supported when the second argument has a higher rank"); + + TensorFunction bFunction = b.function(); + + if (a.type().rank() > b.type().rank()) { + // Well now we have entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + List renameFrom = new ArrayList<>(); + List renameTo = new ArrayList<>(); + int sizeDifference = a.type().rank() - b.type().rank(); + for (int i = 0; i < b.type().rank(); i++) { + renameFrom.add(b.type().dimensions().get(i).name()); + renameTo.add("d" + (sizeDifference + i)); + } + bFunction = new Rename(bFunction, renameFrom, renameTo); + } - TensorType resultType = Join.outputType(a.type(), b.type()); - Join function = new Join(a.function(), b.function(), doubleFunction); - return new TypedTensorFunction(resultType, function); + Join function = new Join(a.function(), bFunction, doubleFunction); + return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank } TypedTensorFunction map(List arguments, DoubleUnaryOperator doubleFunction) { @@ -66,35 +86,37 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction identity(NodeDef tfNode, Map inputs, SavedModelBundle model, - List constants) { - String name; - TensorType type; - if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but has " + - tfNode.getInputList().size()); - name = tfNode.getInput(0); - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); - Session.Runner fetched = model.session().runner().fetch(name); - List> result = fetched.run(); - if ( result.size() != 1) - throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size()); - Tensor constant = tensorConverter.toVespaTensor(result.get(0)); - constants.add(new NamedTensor(name, constant)); - return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); - } - else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name - name = tfNode.getName(); - type = inputs.get(name); - if (type == null) - throw new IllegalArgumentException("An identity operation node is referencing input '" + name + - "', but there is no such input"); - return new TypedTensorFunction(type, new VariableTensor(name)); - } + TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + String name = tfNode.getName(); + TensorType type = result.arguments().get(name); + if (type == null) + throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name 
+ + "', but there is no such input"); + // Included literally in the expression and so must be produced by a separate macro in the rank profile + return new TypedTensorFunction(type, new VariableTensor(name)); + } + + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + if ( ! tfNode.getName().endsWith("/read")) + throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + + "nodes are only supported when reading variables"); + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + + String name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); + Session.Runner fetched = model.session().runner().fetch(name); + List> importedTensors = fetched.run(); + if ( importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + + importedTensors.size()); + Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); + result.set(name, constant); + return new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } TypedTensorFunction matmul(List arguments) { @@ -106,21 +128,18 @@ class OperationMapper { if (a.type().rank() != b.type().rank()) throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first - // and the last dimension of the second argument be not present in the first argument, while leaving the + String afterLastDim = "d" + (a.type().rank() + 1); + // Let the first dimension of the second tensor be the same as the second dimension of the first + // and the second dimension of the second argument be not present in the first argument, while leaving the // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. 
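The dimension renaming above is easier to follow with concrete numbers. Below is a hedged sketch, at the tensor-function level, of the two cases handled here: TensorFlow-style broadcasting of a lower-rank argument onto the trailing dimensions of the first argument, and the implicit-dimension-name matmul. ConstantTensor(Tensor), MapEvaluationContext and TensorFunction.evaluate(...) are assumptions about classes that appear in this change set's file list but whose signatures are not shown here:

import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;

public class DimensionRenameSketch {

    public static void main(String[] args) {
        // A 2x3 "matrix" and a 3-element bias, using the d0,d1,... names the importer assigns
        Tensor m    = Tensor.from("tensor(d0[2],d1[3]):{{d0:0,d1:0}:1, {d0:0,d1:1}:2, {d0:0,d1:2}:3, {d0:1,d1:0}:4, {d0:1,d1:1}:5, {d0:1,d1:2}:6}");
        Tensor bias = Tensor.from("tensor(d0[3]):{{d0:0}:10, {d0:1}:20, {d0:2}:30}");

        // TensorFlow's Add broadcasts the rank-1 bias over the last dimension of the rank-2 input.
        // The join mapping reproduces that by renaming the bias dimension d0 -> d1 before joining:
        TensorFunction broadcastAdd =
                new Join(new ConstantTensor(m),
                         new Rename(new ConstantTensor(bias), "d0", "d1"),
                         ScalarFunctions.add());
        // Expected: rows [11,22,33] and [14,25,36], type tensor(d0[2],d1[3])
        System.out.println(broadcastAdd.evaluate(new MapEvaluationContext()));

        // Matmul uses the same trick: rename the second argument so the summed dimension is shared,
        // (d0,d1) -> (d1,d3) with d3 playing the role of afterLastDim, sum over d1, rename d3 back to d1
        Tensor x = Tensor.from("tensor(d0[1],d1[2]):{{d0:0,d1:0}:1, {d0:0,d1:1}:2}");
        Tensor w = Tensor.from("tensor(d0[2],d1[2]):{{d0:0,d1:0}:3, {d0:0,d1:1}:4, {d0:1,d1:0}:5, {d0:1,d1:1}:6}");
        TensorFunction matmul =
                new Rename(new Matmul(new ConstantTensor(x),
                                      new Rename(new ConstantTensor(w),
                                                 ImmutableList.of("d0", "d1"),
                                                 ImmutableList.of("d1", "d3")),
                                      "d1"),
                           "d3", "d1");
        // Expected: [1*3+2*5, 1*4+2*6] = [13, 16], type tensor(d0[1],d1[2])
        System.out.println(matmul.evaluate(new MapEvaluationContext()));
    }

}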
- // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly - - String beforeLastDim = "d" + (a.type().rank() - 1); - String lastDim = "d" + a.type().rank(); - String afterLastDim = "d" + (a.type().rank() + 1); + // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly - Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim), - ImmutableList.of(lastDim, afterLastDim)); - Matmul matmul = new Matmul(a.function(), renamedB, lastDim); - return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim), - new Rename(matmul, afterLastDim, lastDim)); + Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), + ImmutableList.of("d1", afterLastDim)); + Matmul matmul = new Matmul(a.function(), renamedB, "d1"); + return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), + new Rename(matmul, afterLastDim, "d1")); } TypedTensorFunction softmax(List arguments) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index a74445008b7..df43225c333 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -24,8 +24,10 @@ public class TensorConverter { private TensorType toVespaTensorType(long[] shape) { TensorType.Builder b = new TensorType.Builder(); int dimensionIndex = 0; - for (long dimensionSize : shape) - b.indexed("d" + (dimensionIndex++), (int)dimensionSize); + for (long dimensionSize : shape) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed("d" + (dimensionIndex++), (int) dimensionSize); + } return b.build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 51f1e444e70..33523244129 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -16,11 +16,8 @@ import org.tensorflow.framework.TensorInfo; import org.tensorflow.framework.TensorShapeProto; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.logging.Level; import java.util.stream.Collectors; /** @@ -35,104 +32,100 @@ public class TensorFlowImporter { /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. - * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). + * The name of the model is taken as the db/pbtxt file name (not including the file ending). 
* * @param modelDir the directory containing the TensorFlow model files to import - * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions - * @param logger a receiver of any messages generated by the import process - * @return the ranking expressions resulting from importing this TenorFlow model */ - public List importModel(String modelDir, List constants, MessageLogger logger) { - try { - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger); + public ImportResult importModel(String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } catch (IOException e) { - throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); + throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); } + } + public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); + SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); + ImportResult result = new ImportResult(); + importInputs(signature.getInputsMap(), result); + result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); + return result; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + } } - private List importGraph(MetaGraphDef graph, SavedModelBundle model, - List constants, MessageLogger logger) { - List expressions = new ArrayList<>(); + private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { + ImportResult result = new ImportResult(); for (Map.Entry signatureEntry : graph.getSignatureDefMap().entrySet()) { - Map inputs = importInputs(signatureEntry.getValue().getInputsMap()); + importInputs(signatureEntry.getValue().getInputsMap(), result); for (Map.Entry output : signatureEntry.getValue().getOutputsMap().entrySet()) { try { - ExpressionNode result = importOutput(output.getValue(), - inputs, - graph.getGraphDef(), - model, - constants); - expressions.add(new RankingExpression(output.getKey(), result)); + ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); + result.add(new RankingExpression(output.getKey(), node)); } catch (IllegalArgumentException e) { - logger.log(Level.INFO, "Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + + signatureEntry.getValue().getMethodName() + + "': " + Exceptions.toMessageString(e)); } } } - return expressions; + return result; } - private Map importInputs(Map inputInfoMap) { - Map inputs = new HashMap<>(); - inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), + private void importInputs(Map inputInfoMap, ImportResult result) { + inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), importTensorType(value.getTensorShape()))); - return inputs; } private TensorType importTensorType(TensorShapeProto tensorShape) { TensorType.Builder 
     private TensorType importTensorType(TensorShapeProto tensorShape) {
         TensorType.Builder b = new TensorType.Builder();
-        for (int i = 0; i < tensorShape.getDimCount(); i++) {
-            int dimensionSize = (int) tensorShape.getDim(i).getSize();
+        for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
+            int dimensionSize = (int)dimension.getSize();
             if (dimensionSize >= 0)
-                b.indexed("d" + i, dimensionSize);
+                b.indexed("d" + b.rank(), dimensionSize);
             else
-                b.indexed("d" + i); // unbound size
+                b.indexed("d" + b.rank()); // unbound size
         }
         return b.build();
     }
 
-    private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph,
-                                        SavedModelBundle model, List<NamedTensor> constants) {
-        NodeDef node = getNode(nameOf(output.getName()), graph);
-        TensorFunction function = importNode(node, inputs, graph, model, constants).function();
+    private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
+        return importNode(nameOf(output.getName()), graph, model, result);
+    }
+
+    private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
+        TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
         return new TensorFunctionNode(function); // wrap top level (only) as an expression
     }
 
     /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
-    private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph,
-                                           SavedModelBundle model, List<NamedTensor> constants) {
-        return tensorFunctionOf(tfNode, inputs, graph, model, constants);
+    private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+        return tensorFunctionOf(tfNode, graph, model, result);
     }
 
-    private TypedTensorFunction tensorFunctionOf(NodeDef tfNode,
-                                                 Map<String, TensorType> inputs,
-                                                 GraphDef graph,
-                                                 SavedModelBundle model,
-                                                 List<NamedTensor> constants) {
+    private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
         // Import arguments lazily below, as some nodes have unused arguments leading to unsupported ops
         // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
         switch (tfNode.getOp().toLowerCase()) {
-            case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add());
-            case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos());
-            case "identity" : return operationMapper.identity(tfNode, inputs, model, constants);
-            case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants));
-            case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants));
+            case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+            case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+            case "placeholder" : return operationMapper.placeholder(tfNode, result);
+            case "identity" : return operationMapper.identity(tfNode, model, result);
+            case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+            case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
             default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
         }
     }
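The switch above is the extension point for new TensorFlow operations: a unary op whose semantics match an existing scalar function only needs one more case. A sketch of such a follow-up (not part of this patch, and assuming a ScalarFunctions.exp() helper exists alongside ScalarFunctions.acos()) could add a line like:

    +            case "exp" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.exp());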
-    private List<TypedTensorFunction> importArguments(NodeDef tfNode,
-                                                      Map<String, TensorType> inputs,
-                                                      GraphDef graph,
-                                                      SavedModelBundle model,
-                                                      List<NamedTensor> constants) {
+    private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
         return tfNode.getInputList().stream()
-                     .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants))
+                     .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
                      .collect(Collectors.toList());
     }
 
@@ -151,11 +144,4 @@ public class TensorFlowImporter {
         return name.split(":")[0];
     }
 
-    /** An interface which can be implemented to receive messages emitted during import */
-    public interface MessageLogger {
-
-        void log(Level level, String message);
-
-    }
-
 }
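To make the dimension naming in importTensorType concrete: the MNIST x placeholder imported above has TensorFlow shape [-1, 784], which becomes an indexed Vespa type with one unbound and one bound dimension. A minimal sketch using only the TensorType.Builder calls already used in this patch (the printed type-spec string is an assumption about TensorType.toString()):

    import com.yahoo.tensor.TensorType;

    public class ShapeMappingExample {

        public static void main(String[] args) {
            TensorType.Builder b = new TensorType.Builder();
            b.indexed("d0");       // TensorFlow size -1: unbound indexed dimension
            b.indexed("d1", 784);  // TensorFlow size 784: bound indexed dimension
            System.out.println(b.build());  // expected to print something like tensor(d0[],d1[784])
        }

    }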
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
index 234d620d02f..5712da77700 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -3,9 +3,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
 import com.yahoo.tensor.TensorType;
 import com.yahoo.tensor.functions.TensorFunction;
 
-/** 
+/**
  * A tensor function returning a specific tensor type
- * 
+ *
  * @author bratseth
  */
 final class TypedTensorFunction {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 71699b379b2..d366c9bfbe5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -14,23 +14,23 @@ import java.util.function.*;
 
 /**
  * A tensor generating function, whose arguments are determined by a tensor type
- * 
+ *
  * @author bratseth
 */
 public class GeneratorLambdaFunctionNode extends CompositeNode {
 
     private final TensorType type;
     private final ExpressionNode generator;
-    
+
     public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
         if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
-            throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + 
+            throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
                                                "dimensions, but tried to generate " + type);
         // TODO: Verify that the function only accesses the given arguments
         this.type = type;
         this.generator = generator;
     }
-    
+
     @Override
     public List<ExpressionNode> children() {
         return Collections.singletonList(generator);
@@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode {
     public Value evaluate(Context context) {
         return generator.evaluate(context);
     }
-    
-    /** 
+
+    /**
      * Returns this as an operator which converts a list of integers into a double
      */
     public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
@@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode {
             context.put(type.dimensions().get(i).name(), arguments.get(i));
         return evaluate(context).asDouble();
     }
-    
+
     @Override
     public String toString() {
         return GeneratorLambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index d1f4cbddf6e..8af3448ca6f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -36,10 +36,17 @@ public class TensorFunctionNode extends CompositeNode {
     @Override
     public List<ExpressionNode> children() {
         return function.functionArguments().stream()
-                       .map(f -> ((TensorFunctionExpressionNode)f).expression)
+                       .map(this::toExpressionNode)
                        .collect(Collectors.toList());
     }
 
+    private ExpressionNode toExpressionNode(TensorFunction f) {
+        if (f instanceof TensorFunctionExpressionNode)
+            return ((TensorFunctionExpressionNode)f).expression;
+        else
+            return new TensorFunctionNode(f);
+    }
+
     @Override
     public CompositeNode setChildren(List<ExpressionNode> children) {
         List<TensorFunction> wrappedChildren = children.stream()
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
new file mode 100644
index 00000000000..a1861a1c981
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A very simple MNIST classifier.
+ +See extensive documentation at +https://www.tensorflow.org/get_started/mnist/beginners +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.examples.tutorials.mnist import input_data + +import tensorflow as tf + +FLAGS = None + + +def main(_): + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + W = tf.Variable(tf.zeros([784, 10])) + b = tf.Variable(tf.zeros([10])) + y = tf.matmul(x, W) + b + + # Define loss and optimizer + y_ = tf.placeholder(tf.float32, [None, 10]) + + # The raw formulation of cross-entropy, + # + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # reduction_indices=[1])) + # + # can be numerically unstable. + # + # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw + # outputs of 'y', and then average across the batch. + cross_entropy = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) + + sess = tf.InteractiveSession() + tf.global_variables_initializer().run() + # Train + for _ in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) + + # Test trained model + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + print(sess.run(accuracy, feed_dict={x: mnist.test.images, + y_: mnist.test.labels})) + + # Save the model + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt new file mode 100644 index 00000000000..8100dfd594d --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + 
name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + 
input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: 
"Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg 
{ + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9b01" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + 
dim { + size: 784 + } + } + } + } + } + node { + name: "zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable/Assign" + op: "Assign" + input: "Variable" + input: "zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable/read" + op: "Identity" + input: "Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable_1/Assign" + op: "Assign" + input: "Variable_1" + input: "zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable_1/read" + op: "Identity" + input: "Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "MatMul" + op: "MatMul" + input: "Placeholder" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: 
false + } + } + } + node { + name: "add" + op: "Add" + input: "MatMul" + input: "Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } 
+ } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + 
value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: 
"_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + 
attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } 
+ } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Shape" + op: "Shape" + input: "MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + 
key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + } + node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: 
"gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_Variable/ApplyGradientDescent" + input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^Variable/Assign" + input: "^Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" 
+ attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + 
tensor_shape { + dim { + size: 2 + } + } + string_val: "Variable" + string_val: "Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "Variable" + input: "Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + 
node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..8474aa0a04c Binary files /dev/null and b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 differ diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index 
b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index new file mode 100644 index 00000000000..cfcdac20409 Binary files /dev/null and b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index differ diff --git a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt deleted file mode 100644 index e01688669a1..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/model1/saved_model.pbtxt +++ /dev/null @@ -1,4909 +0,0 @@ -saved_model_schema_version: 1 -meta_graphs { - meta_info_def { - stripped_op_list { - op { - name: "Add" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_STRING - } - } - } - } - op { - name: "ApplyGradientDescent" - input_arg { - name: "var" - type_attr: "T" - is_ref: true - } - input_arg { - name: "alpha" - type_attr: "T" - } - input_arg { - name: "delta" - type_attr: "T" - } - output_arg { - name: "out" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - } - } - op { - name: "ArgMax" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "dimension" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "output_type" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "output_type" - type: "type" - default_value { - type: DT_INT64 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Assign" - input_arg { - name: "ref" - type_attr: "T" - is_ref: true - } - input_arg { - name: "value" - type_attr: "T" - } - output_arg { - name: "output_ref" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - } - attr { - name: "validate_shape" - type: "bool" - default_value { - b: true - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: true - } - } - allows_uninitialized_input: true - } - op { - name: "BroadcastGradientArgs" - input_arg { - name: "s0" - type_attr: "T" - } - input_arg { - name: "s1" - type_attr: "T" - } - output_arg { - name: "r0" - type_attr: "T" - } - output_arg { - name: "r1" - type_attr: "T" - } - attr { - name: "T" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Cast" - input_arg { - name: "x" - type_attr: 
"SrcT" - } - output_arg { - name: "y" - type_attr: "DstT" - } - attr { - name: "SrcT" - type: "type" - } - attr { - name: "DstT" - type: "type" - } - } - op { - name: "Const" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "value" - type: "tensor" - } - attr { - name: "dtype" - type: "type" - } - } - op { - name: "Equal" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type: DT_BOOL - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_QUINT8 - type: DT_QINT8 - type: DT_QINT32 - type: DT_STRING - type: DT_BOOL - type: DT_COMPLEX128 - } - } - } - is_commutative: true - } - op { - name: "Fill" - input_arg { - name: "dims" - type: DT_INT32 - } - input_arg { - name: "value" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - op { - name: "HashTableV2" - output_arg { - name: "table_handle" - type: DT_RESOURCE - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - attr { - name: "use_node_name_sharing" - type: "bool" - default_value { - b: false - } - } - attr { - name: "key_dtype" - type: "type" - } - attr { - name: "value_dtype" - type: "type" - } - is_stateful: true - } - op { - name: "Identity" - input_arg { - name: "input" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - op { - name: "InitializeTableV2" - input_arg { - name: "table_handle" - type: DT_RESOURCE - } - input_arg { - name: "keys" - type_attr: "Tkey" - } - input_arg { - name: "values" - type_attr: "Tval" - } - attr { - name: "Tkey" - type: "type" - } - attr { - name: "Tval" - type: "type" - } - is_stateful: true - } - op { - name: "Log" - input_arg { - name: "x" - type_attr: "T" - } - output_arg { - name: "y" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "LookupTableFindV2" - input_arg { - name: "table_handle" - type: DT_RESOURCE - } - input_arg { - name: "keys" - type_attr: "Tin" - } - input_arg { - name: "default_value" - type_attr: "Tout" - } - output_arg { - name: "values" - type_attr: "Tout" - } - attr { - name: "Tin" - type: "type" - } - attr { - name: "Tout" - type: "type" - } - is_stateful: true - } - op { - name: "MatMul" - input_arg { - name: "a" - type_attr: "T" - } - input_arg { - name: "b" - type_attr: "T" - } - output_arg { - name: "product" - type_attr: "T" - } - attr { - name: "transpose_a" - type: "bool" - default_value { - b: false - } - } - attr { - name: "transpose_b" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Mean" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "reduction_indices" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "keep_dims" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values 
{ - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "MergeV2Checkpoints" - input_arg { - name: "checkpoint_prefixes" - type: DT_STRING - } - input_arg { - name: "destination_prefix" - type: DT_STRING - } - attr { - name: "delete_old_dirs" - type: "bool" - default_value { - b: true - } - } - is_stateful: true - } - op { - name: "Mul" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - is_commutative: true - } - op { - name: "Neg" - input_arg { - name: "x" - type_attr: "T" - } - output_arg { - name: "y" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "NoOp" - } - op { - name: "Pack" - input_arg { - name: "values" - type_attr: "T" - number_attr: "N" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "T" - type: "type" - } - attr { - name: "axis" - type: "int" - default_value { - i: 0 - } - } - } - op { - name: "ParseExample" - input_arg { - name: "serialized" - type: DT_STRING - } - input_arg { - name: "names" - type: DT_STRING - } - input_arg { - name: "sparse_keys" - type: DT_STRING - number_attr: "Nsparse" - } - input_arg { - name: "dense_keys" - type: DT_STRING - number_attr: "Ndense" - } - input_arg { - name: "dense_defaults" - type_list_attr: "Tdense" - } - output_arg { - name: "sparse_indices" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "sparse_values" - type_list_attr: "sparse_types" - } - output_arg { - name: "sparse_shapes" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "dense_values" - type_list_attr: "Tdense" - } - attr { - name: "Nsparse" - type: "int" - has_minimum: true - } - attr { - name: "Ndense" - type: "int" - has_minimum: true - } - attr { - name: "sparse_types" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "Tdense" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "dense_shapes" - type: "list(shape)" - has_minimum: true - } - } - op { - name: "Placeholder" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "shape" - type: "shape" - default_value { - shape { - unknown_rank: true - } - } - } - } - op { - name: "Range" - input_arg { - name: "start" - type_attr: "Tidx" - } - input_arg { - name: "limit" - type_attr: "Tidx" - } - input_arg { - name: "delta" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "Tidx" - } - 
attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Reciprocal" - input_arg { - name: "x" - type_attr: "T" - } - output_arg { - name: "y" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Reshape" - input_arg { - name: "tensor" - type_attr: "T" - } - input_arg { - name: "shape" - type_attr: "Tshape" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "RestoreV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - output_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - is_stateful: true - } - op { - name: "SaveV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - input_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - is_stateful: true - } - op { - name: "Shape" - input_arg { - name: "input" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "out_type" - } - attr { - name: "T" - type: "type" - } - attr { - name: "out_type" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "ShardedFilename" - input_arg { - name: "basename" - type: DT_STRING - } - input_arg { - name: "shard" - type: DT_INT32 - } - input_arg { - name: "num_shards" - type: DT_INT32 - } - output_arg { - name: "filename" - type: DT_STRING - } - } - op { - name: "Softmax" - input_arg { - name: "logits" - type_attr: "T" - } - output_arg { - name: "softmax" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - } - op { - name: "StringJoin" - input_arg { - name: "inputs" - type: DT_STRING - number_attr: "N" - } - output_arg { - name: "output" - type: DT_STRING - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "separator" - type: "string" - default_value { - s: "" - } - } - } - op { - name: "Sub" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Sum" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "reduction_indices" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "keep_dims" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - 
type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Tile" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "multiples" - type_attr: "Tmultiples" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tmultiples" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "TopKV2" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "k" - type: DT_INT32 - } - output_arg { - name: "values" - type_attr: "T" - } - output_arg { - name: "indices" - type: DT_INT32 - } - attr { - name: "sorted" - type: "bool" - default_value { - b: true - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_UINT16 - type: DT_HALF - } - } - } - } - op { - name: "VariableV2" - output_arg { - name: "ref" - type_attr: "dtype" - is_ref: true - } - attr { - name: "shape" - type: "shape" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - is_stateful: true - } - } - tags: "serve" - tensorflow_version: "1.3.0" - tensorflow_git_version: "v1.3.0-rc2-20-g0787eee" - } - graph_def { - node { - name: "tf_example" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - unknown_rank: true - } - } - } - } - node { - name: "ParseExample/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/dense_keys_0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "x" - } - } - } - } - node { - name: "ParseExample/ParseExample" - op: "ParseExample" - input: "tf_example" - input: "ParseExample/ParseExample/names" - input: "ParseExample/ParseExample/dense_keys_0" - input: "ParseExample/Const" - attr { - key: "Ndense" - value { - i: 1 - } - } - attr { - key: "Nsparse" - value { - i: 0 - } - } - attr { - key: "Tdense" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - 
} - } - } - } - } - attr { - key: "dense_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - } - } - } - attr { - key: "sparse_types" - value { - list { - } - } - } - } - node { - name: "x" - op: "Identity" - input: "ParseExample/ParseExample" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - } - node { - name: "Placeholder" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - node { - name: "zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - float_val: 0.0 - } - } - } - } - node { - name: "Variable" - op: "VariableV2" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "Variable/Assign" - op: "Assign" - input: "Variable" - input: "zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "Variable/read" - op: "Identity" - input: "Variable" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "zeros_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 10 - } - } - float_val: 0.0 - } - } - } - } - node { - name: "Variable_1" - op: "VariableV2" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "Variable_1/Assign" - op: "Assign" - input: "Variable_1" - input: "zeros_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { 
- key: "validate_shape" - value { - b: true - } - } - } - node { - name: "Variable_1/read" - op: "Identity" - input: "Variable_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "init" - op: "NoOp" - input: "^Variable/Assign" - input: "^Variable_1/Assign" - } - node { - name: "MatMul" - op: "MatMul" - input: "x" - input: "Variable/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } - } - node { - name: "add" - op: "Add" - input: "MatMul" - input: "Variable_1/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "y" - op: "Softmax" - input: "add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "Log" - op: "Log" - input: "y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "mul" - op: "Mul" - input: "Placeholder" - input: "Log" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } - } - node { - name: "Sum" - op: "Sum" - input: "mul" - input: "Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "Neg" - op: "Neg" - input: "Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "gradients/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } - } - node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } 
- } - } - node { - name: "gradients/Neg_grad/Neg" - op: "Neg" - input: "gradients/Fill" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Sum_grad/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\001\000\000\000\001\000\000\000" - } - } - } - } - node { - name: "gradients/Sum_grad/Reshape" - op: "Reshape" - input: "gradients/Neg_grad/Neg" - input: "gradients/Sum_grad/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "gradients/Sum_grad/Shape" - op: "Shape" - input: "mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/Sum_grad/Tile" - op: "Tile" - input: "gradients/Sum_grad/Reshape" - input: "gradients/Sum_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/Shape" - op: "Shape" - input: "Placeholder" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/mul_grad/Shape_1" - op: "Shape" - input: "Log" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/mul_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/mul_grad/Shape" - input: "gradients/mul_grad/Shape_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/mul" - op: "Mul" - input: "gradients/Sum_grad/Tile" - input: "Log" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/Sum" - op: "Sum" - input: "gradients/mul_grad/mul" - input: "gradients/mul_grad/BroadcastGradientArgs" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/mul_grad/Reshape" - op: "Reshape" - input: "gradients/mul_grad/Sum" - input: "gradients/mul_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - 
value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/mul_1" - op: "Mul" - input: "Placeholder" - input: "gradients/Sum_grad/Tile" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/Sum_1" - op: "Sum" - input: "gradients/mul_grad/mul_1" - input: "gradients/mul_grad/BroadcastGradientArgs:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/mul_grad/Reshape_1" - op: "Reshape" - input: "gradients/mul_grad/Sum_1" - input: "gradients/mul_grad/Shape_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/mul_grad/Reshape" - input: "^gradients/mul_grad/Reshape_1" - } - node { - name: "gradients/mul_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/mul_grad/Reshape" - input: "^gradients/mul_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/mul_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/mul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/mul_grad/Reshape_1" - input: "^gradients/mul_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/mul_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/Log_grad/Reciprocal" - op: "Reciprocal" - input: "y" - input: "^gradients/mul_grad/tuple/control_dependency_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/Log_grad/mul" - op: "Mul" - input: "gradients/mul_grad/tuple/control_dependency_1" - input: "gradients/Log_grad/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/y_grad/mul" - op: "Mul" - input: "gradients/Log_grad/mul" - input: "y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/y_grad/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { 
- dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node { - name: "gradients/y_grad/Sum" - op: "Sum" - input: "gradients/y_grad/mul" - input: "gradients/y_grad/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/y_grad/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\377\377\377\377\001\000\000\000" - } - } - } - } - node { - name: "gradients/y_grad/Reshape" - op: "Reshape" - input: "gradients/y_grad/Sum" - input: "gradients/y_grad/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "gradients/y_grad/sub" - op: "Sub" - input: "gradients/Log_grad/mul" - input: "gradients/y_grad/Reshape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/y_grad/mul_1" - op: "Mul" - input: "gradients/y_grad/sub" - input: "y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Shape" - op: "Shape" - input: "MatMul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/add_grad/Shape_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } - } - node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/y_grad/mul_1" - input: "gradients/add_grad/BroadcastGradientArgs" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - 
shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/y_grad/mul_1" - input: "gradients/add_grad/BroadcastGradientArgs:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - } - node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "Variable/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } - } - node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "x" - input: "gradients/add_grad/tuple/control_dependency" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } - } - node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - } - node { - name: "gradients/MatMul_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - } - node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - attr { - 
key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "GradientDescent/learning_rate" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00999999977648 - } - } - } - } - node { - name: "GradientDescent/update_Variable/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } - } - node { - name: "GradientDescent/update_Variable_1/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable_1" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } - } - node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_Variable/ApplyGradientDescent" - input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" - } - node { - name: "TopKV2/k" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } - } - node { - name: "TopKV2" - op: "TopKV2" - input: "y" - input: "TopKV2/k" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "sorted" - value { - b: true - } - } - } - node { - name: "Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 10 - } - } - string_val: "0" - string_val: "1" - string_val: "2" - string_val: "3" - string_val: "4" - string_val: "5" - string_val: "6" - string_val: "7" - string_val: "8" - string_val: "9" - } - } - } - } - node { - name: "index_to_string/Size" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } - } - node { - name: "index_to_string/range/start" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { 
- tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "index_to_string/range/delta" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "index_to_string/range" - op: "Range" - input: "index_to_string/range/start" - input: "index_to_string/Size" - input: "index_to_string/range/delta" - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "index_to_string/ToInt64" - op: "Cast" - input: "index_to_string/range" - attr { - key: "DstT" - value { - type: DT_INT64 - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "index_to_string" - op: "HashTableV2" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "key_dtype" - value { - type: DT_INT64 - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - attr { - key: "use_node_name_sharing" - value { - b: false - } - } - attr { - key: "value_dtype" - value { - type: DT_STRING - } - } - } - node { - name: "index_to_string/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "UNK" - } - } - } - } - node { - name: "index_to_string/table_init" - op: "InitializeTableV2" - input: "index_to_string" - input: "index_to_string/ToInt64" - input: "Const_1" - attr { - key: "Tkey" - value { - type: DT_INT64 - } - } - attr { - key: "Tval" - value { - type: DT_STRING - } - } - } - node { - name: "ToInt64" - op: "Cast" - input: "TopKV2:1" - attr { - key: "DstT" - value { - type: DT_INT64 - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "index_to_string_Lookup" - op: "LookupTableFindV2" - input: "index_to_string" - input: "ToInt64" - input: "index_to_string/Const" - attr { - key: "Tin" - value { - type: DT_INT64 - } - } - attr { - key: "Tout" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "ArgMax/dimension" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "ArgMax" - op: "ArgMax" - input: "y" - input: "ArgMax/dimension" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_type" - value { - type: DT_INT64 - } - } - } - node { - name: "ArgMax_1/dimension" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - 
attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "ArgMax_1" - op: "ArgMax" - input: "Placeholder" - input: "ArgMax_1/dimension" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_type" - value { - type: DT_INT64 - } - } - } - node { - name: "Equal" - op: "Equal" - input: "ArgMax" - input: "ArgMax_1" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Cast" - op: "Cast" - input: "Equal" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Const_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "Mean" - op: "Mean" - input: "Cast" - input: "Const_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "init_all_tables" - op: "NoOp" - input: "^index_to_string/table_init" - } - node { - name: "legacy_init_op" - op: "NoOp" - input: "^init_all_tables" - } - node { - name: "save/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "model" - } - } - } - } - node { - name: "save/StringJoin/inputs_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "_temp_8390c48b96834292ab57050d0ae6959e/part" - } - } - } - } - node { - name: "save/StringJoin" - op: "StringJoin" - input: "save/Const" - input: "save/StringJoin/inputs_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "separator" - value { - s: "" - } - } - } - node { - name: "save/num_shards" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "save/ShardedFilename/shard" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "save/ShardedFilename" - op: "ShardedFilename" - input: "save/StringJoin" - input: 
"save/ShardedFilename/shard" - input: "save/num_shards" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/SaveV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "Variable" - string_val: "Variable_1" - } - } - } - } - node { - name: "save/SaveV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "" - string_val: "" - } - } - } - } - node { - name: "save/SaveV2" - op: "SaveV2" - input: "save/ShardedFilename" - input: "save/SaveV2/tensor_names" - input: "save/SaveV2/shape_and_slices" - input: "Variable" - input: "Variable_1" - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - } - } - } - } - node { - name: "save/control_dependency" - op: "Identity" - input: "save/ShardedFilename" - input: "^save/SaveV2" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@save/ShardedFilename" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/MergeV2Checkpoints/checkpoint_prefixes" - op: "Pack" - input: "save/ShardedFilename" - input: "^save/control_dependency" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "save/MergeV2Checkpoints" - op: "MergeV2Checkpoints" - input: "save/MergeV2Checkpoints/checkpoint_prefixes" - input: "save/Const" - attr { - key: "delete_old_dirs" - value { - b: true - } - } - } - node { - name: "save/Identity" - op: "Identity" - input: "save/Const" - input: "^save/control_dependency" - input: "^save/MergeV2Checkpoints" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/RestoreV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "Variable" - } - } - } - } - node { - name: "save/RestoreV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2/tensor_names" - input: "save/RestoreV2/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign" - op: "Assign" - 
input: "Variable" - input: "save/RestoreV2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/RestoreV2_1/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "Variable_1" - } - } - } - } - node { - name: "save/RestoreV2_1/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2_1" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2_1/tensor_names" - input: "save/RestoreV2_1/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign_1" - op: "Assign" - input: "Variable_1" - input: "save/RestoreV2_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/restore_shard" - op: "NoOp" - input: "^save/Assign" - input: "^save/Assign_1" - } - node { - name: "save/restore_all" - op: "NoOp" - input: "^save/restore_shard" - } - versions { - producer: 24 - } - } - saver_def { - filename_tensor_name: "save/Const:0" - save_tensor_name: "save/Identity:0" - restore_op_name: "save/restore_all" - max_to_keep: 5 - sharded: true - keep_checkpoint_every_n_hours: 10000.0 - version: V2 - } - collection_def { - key: "legacy_init_op" - value { - node_list { - value: "legacy_init_op" - } - } - } - collection_def { - key: "table_initializer" - value { - node_list { - value: "index_to_string/table_init" - } - } - } - collection_def { - key: "train_op" - value { - node_list { - value: "GradientDescent" - } - } - } - collection_def { - key: "trainable_variables" - value { - bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0" - } - } - } - collection_def { - key: "variables" - value { - bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:0" - } - } - } - signature_def { - key: "predict_images" - value { - inputs { - key: "images" - value { - name: "x:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - outputs { - key: "scores" - value { - name: "y:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - method_name: 
"tensorflow/serving/predict" - } - } - signature_def { - key: "serving_default" - value { - inputs { - key: "inputs" - value { - name: "tf_example:0" - dtype: DT_STRING - tensor_shape { - unknown_rank: true - } - } - } - outputs { - key: "classes" - value { - name: "index_to_string_Lookup:0" - dtype: DT_STRING - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - outputs { - key: "scores" - value { - name: "TopKV2:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - method_name: "tensorflow/serving/classify" - } - } -} diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 deleted file mode 100644 index ba71c21fbe1..00000000000 Binary files a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index deleted file mode 100644 index 84f4593515a..00000000000 Binary files a/searchlib/src/test/files/integration/tensorflow/model1/variables/variables.index and /dev/null differ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 82e5d0cfe5b..3aa2d144f1f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.tensor.Tensor; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; import org.junit.Test; + import static org.junit.Assert.assertEquals; /** @@ -83,7 +88,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0, "sin(0)"); tester.assertEvaluates(1, "cos(0)"); tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))"); - + // Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero) tester.assertEvaluates(0, "random(1)"); tester.assertEvaluates(0, "random(foo)"); @@ -152,7 +157,7 @@ public class EvaluationTestCase { "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); - + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); diff --git 
a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ee2b1c147e3..ba0db4de5e1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -34,7 +34,7 @@ public class EvaluationTester {
     }
 
     // TODO: Test both bound and unbound indexed
-    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, 
+    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
                                              String ... tensorArgumentStrings) {
         MapContext context = defaultContext.thawedCopy();
         int argumentIndex = 0;
@@ -46,7 +46,7 @@ public class EvaluationTester {
                 argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
             context.put("tensor" + (argumentIndex++), new TensorValue(argument));
         }
-        return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, 
+        return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
                                mappedTensors ? "Mapped tensors" : "Indexed tensors");
     }
 
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
new file mode 100644
index 00000000000..dab42801d70
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -0,0 +1,114 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Mnist_SoftmaxTestCase {
+
+    @Test
+    public void testImporting() {
+        String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+        ImportResult result = new TensorFlowImporter().importModel(modelDir);
+
+        // Check logged messages
+        result.warnings().forEach(System.err::println);
+        assertEquals(0, result.warnings().size());
+
+        // Check arguments
+        assertEquals(1, result.arguments().size());
+        TensorType argument0 = result.arguments().get("Placeholder");
+        assertNotNull(argument0);
+        assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+        // Check constants
+        assertEquals(2, result.constants().size());
+
+        Tensor constant0 = result.constants().get("Variable");
+        assertNotNull(constant0);
+        assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+                     constant0.type());
+        assertEquals(7840, constant0.size());
+
+        Tensor constant1 = result.constants().get("Variable_1");
+        assertNotNull(constant1);
+        assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+                     constant1.type());
+        assertEquals(10, constant1.size());
+
+        // Check resulting Vespa expression
+        assertEquals(1, result.expressions().size());
+        assertEquals("y", result.expressions().get(0).getName());
+        assertEquals("" +
+                     "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
+                     "rename(constant(Variable_1), d0, d1), " +
+                     "f(a,b)(a + b))",
+                     toNonPrimitiveString(result.expressions().get(0)));
+
+        // Test execution
+        String signatureName = "serving_default";
+
+        assertEqualResult(modelDir, signatureName, "Variable/read");
+        assertEqualResult(modelDir, signatureName, "Variable_1/read");
+        // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
+        assertEqualResult(modelDir, signatureName, "MatMul");
+        assertEqualResult(modelDir, signatureName, "add");
+    }
+
+    private void assertEqualResult(String modelDir, String signatureName, String operationName) {
+        ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
+
+        Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+        Context context = contextFrom(result);
+        Tensor placeholder = placeholderArgument();
+        context.put("Placeholder", new TensorValue(placeholder));
+        Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+        assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+    }
+
+    private Tensor tensorFlowExecute(String modelDir, String operationName) {
+        SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+        Session.Runner runner = model.session().runner();
+        org.tensorflow.Tensor placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+        runner.feed("Placeholder", placeholder);
+        List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+        assertEquals(1, results.size());
+        return new TensorConverter().toVespaTensor(results.get(0));
+    }
+
+    private Context contextFrom(ImportResult result) {
+        MapContext context = new MapContext();
+        result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+        return context;
+    }
+
+    private String toNonPrimitiveString(RankingExpression expression) {
+        // toString on the wrapping expression will map to primitives, which is harder to read
+        return ((TensorFunctionNode)expression.getRoot()).function().toString();
+    }
+
+    private Tensor placeholderArgument() {
+        int size = 784;
+        Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
+        for (int i = 0; i < size; i++)
+            b.cell(0, 0, i);
+        return b.build();
+    }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
deleted file mode 100644
index aaf198a9e8f..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ /dev/null
@@ -1,79 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.logging.Level;
-import java.util.stream.Collectors;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author bratseth
- */
-public class TensorFlowImporterTestCase {
-
-    @Test
-    public void testModel1() {
-        List<NamedTensor> constants = new ArrayList<>();
-        TestLogger logger = new TestLogger();
-        List<RankingExpression> expressions =
-                new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger);
-
-        // Check constants
-        assertEquals(2, constants.size());
-
-        assertEquals("Variable", constants.get(0).name());
-        assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
-                     constants.get(0).tensor().type());
-        assertEquals(7840, constants.get(0).tensor().size());
-
-        assertEquals("Variable_1", constants.get(1).name());
-        assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
-                     constants.get(1).tensor().type());
-        assertEquals(10, constants.get(1).tensor().size());
-
-        // Check logged messages
-        assertEquals(2, logger.messages().size());
-        assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported",
-                     logger.messages().get(0));
-        assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
-                     logger.messages().get(1));
-
-        // Check resulting Vespa expression
-        assertEquals(1, expressions.size());
-        assertEquals("scores", expressions.get(0).getName());
-        assertEquals("" +
-                     "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " +
-                     "constant(Variable_1), " +
-                     "f(a,b)(a + b)), " +
-                     "d0)",
-                     toNonPrimitiveString(expressions.get(0)));
-    }
-
-    private String toNonPrimitiveString(RankingExpression expression) {
-        // toString on the wrapping expression will map to primitives, which is harder to read
-        return ((TensorFunctionNode)expression.getRoot()).function().toString();
-    }
-
-    private class TestLogger implements TensorFlowImporter.MessageLogger {
-
-        private List<String> messages = new ArrayList<>();
-
-        /** Returns the messages in sorted order */
-        public List<String> messages() {
-            return messages.stream().sorted().collect(Collectors.toList());
-        }
-
-        @Override
-        public void log(Level level, String message) {
-            messages.add(message);
-        }
-
-    }
-
-}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index dde9d4bf21e..1960c1fe876 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -59,7 +59,7 @@ public class TensorConformanceTest {
         try {
             ObjectMapper mapper = new ObjectMapper();
             JsonNode node = mapper.readTree(test);
-            
+
             if (node.has("num_tests")) {
                 Assert.assertEquals(node.get("num_tests").asInt(), count);
                 return true;
@@ -67,7 +67,7 @@ public class TensorConformanceTest {
             if (!node.has("expression")) {
                 return true; // ignore
             }
-            
+
             String expression = node.get("expression").asText();
             MapContext context = getInput(node.get("inputs"));
             Tensor expect = getTensor(node.get("result").get("expect").asText());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 00e106dd035..f6237a1977a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ 
b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -7,7 +7,7 @@ import java.util.Arrays; /** * The sizes of a set of dimensions. - * + * * @author bratseth */ @Beta @@ -48,7 +48,7 @@ public final class DimensionSizes { @Override public int hashCode() { return Arrays.hashCode(sizes); } - /** + /** * Builder of a set of dimension sizes. * Dimensions whose size is not set before building will get size 0. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index c207dabca3a..6b0d769de9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; - + /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - + private final double[] values; - + private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { this.type = type; this.dimensionSizes = dimensionSizes; @@ -43,8 +43,8 @@ public class IndexedTensor implements Tensor { } /** - * Returns an iterator over the cells of this. - * Cells are returned in order of increasing indexes in each dimension, increasing + * Returns an iterator over the cells of this. + * Cells are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor { /** * Returns an iterator over the values of this. - * Values are returned in order of increasing indexes in each dimension, increasing + * Values are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor { * Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions * given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as * other iterator. - * + * * @param dimensions the names of the dimensions of the superspace * @param sizes the size of each dimension in the space we are returning values for, containing * one value per dimension of this tensor (in order). Each size may be the same or smaller @@ -96,9 +96,9 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } - /** + /** * Returns the value at the given indexes - * + * * @param indexes the indexes into the dimensions of this. 
Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ @@ -119,7 +119,7 @@ public class IndexedTensor implements Tensor { } private double get(int valueIndex) { return values[valueIndex]; } - + private static int toValueIndex(int[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed @@ -165,7 +165,7 @@ public class IndexedTensor implements Tensor { public Map cells() { if (dimensionSizes.dimensions() == 0) return Collections.singletonMap(TensorAddress.of(), values[0]); - + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { @@ -174,13 +174,13 @@ public class IndexedTensor implements Tensor { } return builder.build(); } - + @Override public int hashCode() { return Arrays.hashCode(values); } @Override public String toString() { return Tensor.toStandardString(this); } - + @Override public boolean equals(Object other) { if ( ! ( other instanceof Tensor)) return false; @@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor { } public abstract static class Builder implements Tensor.Builder { - + final TensorType type; - + private Builder(TensorType type) { this.type = type; } @@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor { return new UnboundBuilder(type); } - /** + /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, * and, agree with the type size information when specified in the type. * If sizes are completely specified in the type this size information is redundant. @@ -210,16 +210,16 @@ public class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + + throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + "for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + sizes.size(i) + " but cannot be larger than " + size.get() + " in " + type); } - + return new BoundBuilder(type, sizes); } @@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor { public abstract IndexedTensor build(); } - + /** A bound builder can create the double array directly */ public static class BoundBuilder extends Builder { @@ -257,13 +257,13 @@ public class IndexedTensor implements Tensor { this.sizes = sizes; values = new double[sizes.totalSize()]; } - + @Override public BoundBuilder cell(double value, int ... indexes) { values[toValueIndex(indexes, sizes)] = value; return this; } - + @Override public CellBuilder cell() { return new CellBuilder(type, this); @@ -294,8 +294,8 @@ public class IndexedTensor implements Tensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. + /** + * Set a cell value by the index in the internal layout of this cell. 
* This requires knowledge of the internal layout of cells in this implementation, and should therefore * probably not be used (but when it can be used it is fast). */ @@ -330,7 +330,7 @@ public class IndexedTensor implements Tensor { fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } - + private DimensionSizes findDimensionSizes(List firstDimension) { List dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); @@ -347,16 +347,16 @@ public class IndexedTensor implements Tensor { if (currentDimensionIndex == dimensionSizes.size()) dimensionSizes.add(currentDimension.size()); else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size()) - throw new IllegalArgumentException("Missing values in dimension " + + throw new IllegalArgumentException("Missing values in dimension " + type.dimensions().get(currentDimensionIndex) + " in " + type); - + for (Object value : currentDimension) if (value instanceof List) findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List)value); } @SuppressWarnings("unchecked") - private void fillValues(int currentDimensionIndex, int offset, List currentDimension, + private void fillValues(int currentDimensionIndex, int offset, List currentDimension, DimensionSizes sizes, double[] values) { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (int i = 0; i < currentDimension.size(); i++) @@ -369,7 +369,7 @@ public class IndexedTensor implements Tensor { } } } - + private double nullAsZero(Double value) { if (value == null) return 0; return value; @@ -431,7 +431,7 @@ public class IndexedTensor implements Tensor { } } - + private final class CellIterator implements Iterator { private int count = 0; @@ -451,7 +451,7 @@ public class IndexedTensor implements Tensor { reusedCell.value = get(indexes.toSourceValueIndex()); return reusedCell; } - + } private final class ValueIterator implements Iterator { @@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor { } } - + private final class SuperspaceIterator implements Iterator { private final Indexes superindexes; /** Those indexes this should iterate over */ private final List subdimensionIndexes; - - /** + + /** * The sizes of the space we'll return values of, one value for each dimension of this tensor, - * which may be equal to or smaller than the sizes of this tensor + * which may be equal to or smaller than the sizes of this tensor */ private final DimensionSizes iterateSizes; private int count = 0; - + private SuperspaceIterator(Set superdimensionNames, DimensionSizes iterateSizes) { this.iterateSizes = iterateSizes; - + List superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first @@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor { else subdimensionIndexes.add(i); } - + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes); } - + @Override public boolean hasNext() { return count < superindexes.size(); @@ -527,7 +527,7 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator { - /** + /** * This iterator will iterate over the given dimensions, in the order given * (the first dimension 
index given is incremented to exhaustion first (i.e is etc.). * This may be any subset of the dimensions given by address and dimensionSizes. @@ -538,21 +538,21 @@ public class IndexedTensor implements Tensor { private Indexes indexes; private int count = 0; - + /** A lazy cell for reuse */ private final LazyCell reusedCell; - - /** + + /** * Creates a new subspace iterator - * + * * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the * type of the tensor this iterates over. This iterator will iterate over these - * dimensions to exhaustion in the order given (the first dimension index given is + * dimensions to exhaustion in the order given (the first dimension index given is * incremented to exhaustion first (i.e is etc.), while other dimensions will be held * at a constant position. * This may be any subset of the dimensions given by address and dimensionSizes. * This is treated as immutable. - * @param address the address of the first cell of this subspace. + * @param address the address of the first cell of this subspace. */ private SubspaceIterator(List iterateDimensions, int[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; @@ -561,26 +561,26 @@ public class IndexedTensor implements Tensor { this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); reusedCell = new LazyCell(indexes, Double.NaN); } - + /** Returns the total number of cells in this subspace */ - public int size() { + public int size() { return indexes.size(); } - + /** Returns the address of the cell this currently points to (which may be an invalid position) */ public TensorAddress address() { return indexes.toAddress(); } - + /** Rewind this iterator to the first element */ - public void reset() { + public void reset() { this.count = 0; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); } - + @Override public boolean hasNext() { - return count < indexes.size(); + return count < indexes.size(); } - + /** Returns the next cell, which is valid until next() is called again */ @Override public Cell next() { @@ -611,15 +611,15 @@ public class IndexedTensor implements Tensor { public TensorAddress getKey() { return indexes.toAddress(); } - + @Override public Double getValue() { return value; } } // TODO: Make dimensionSizes a class - - /** + + /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once * before accessing the first position. 
@@ -631,7 +631,7 @@ public class IndexedTensor implements Tensor { private final DimensionSizes iterationSizes; protected final int[] indexes; - + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -676,14 +676,14 @@ public class IndexedTensor implements Tensor { return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size); } } - + private static List completeIterationOrder(int length) { List iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) iterationDimensions.add(length - 1 - i); return iterationDimensions; } - + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { this.sourceSizes = sourceSizes; this.iterationSizes = iterationSizes; @@ -708,9 +708,9 @@ public class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public int[] indexesForReading() { return indexes; } - - int toSourceValueIndex() { - return IndexedTensor.toValueIndex(indexes, sourceSizes); + + int toSourceValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceSizes); } int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } @@ -729,9 +729,9 @@ public class IndexedTensor implements Tensor { public String toString() { return "indexes " + Arrays.toString(indexes); } - + public abstract int size(); - + public abstract void next(); } @@ -763,18 +763,18 @@ public class IndexedTensor implements Tensor { public void next() {} } - + private static class MultiDimensionIndexes extends Indexes { private final int size; private final List iterateDimensions; - + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int[] initialIndexes, int size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; - + // Initialize to the (virtual) position before the first cell indexes[iterateDimensions.get(0)]--; } @@ -785,10 +785,10 @@ public class IndexedTensor implements Tensor { return size; } - /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. - * + /** + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. 
+ * * @throws RuntimeException if this is called more times than its size */ @Override @@ -802,12 +802,12 @@ public class IndexedTensor implements Tensor { } } - + /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { private int lastComputedSourceValueIndex = -1; - + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List iterateDimensions, int[] initialIndexes, int size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); } @@ -827,7 +827,7 @@ public class IndexedTensor implements Tensor { private final int size; private final int iterateDimension; - + /** Maintain this directly as an optimization for 1-d iteration */ private int currentSourceValueIndex, currentIterationValueIndex; @@ -847,7 +847,7 @@ public class IndexedTensor implements Tensor { currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes); currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override public int size() { @@ -855,8 +855,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ @@ -888,7 +888,7 @@ public class IndexedTensor implements Tensor { /** The iteration step in the value index space */ private final int step; - private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, + private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, int iterateDimension, int[] initialIndexes, int size) { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; @@ -907,8 +907,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. 
* * @throws RuntimeException if this is called more times than its size */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 618bff0caae..aba61478e69 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -27,7 +27,7 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } - + @Override public int size() { return cells.size(); } @@ -56,16 +56,16 @@ public class MappedTensor implements Tensor { } public static class Builder implements Tensor.Builder { - + private final TensorType type; private final ImmutableMap.Builder cells = new ImmutableMap.Builder<>(); - + public static Builder of(TensorType type) { return new Builder(type); } private Builder(TensorType type) { this.type = type; } - + public CellBuilder cell() { return new CellBuilder(type, this); } @@ -89,24 +89,24 @@ public class MappedTensor implements Tensor { public MappedTensor build() { return new MappedTensor(type, cells.build()); } - + } private static class CellIteratorAdaptor implements Iterator { private final Iterator> adaptedIterator; - + private CellIteratorAdaptor(Iterator> adaptedIterator) { this.adaptedIterator = adaptedIterator; } - + @Override public boolean hasNext() { return adaptedIterator.hasNext(); } @Override public Cell next() { Map.Entry entry = adaptedIterator.next(); - return new Cell(entry.getKey(), entry.getValue()); + return new Cell(entry.getKey(), entry.getValue()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 79bb27fcd1b..9a751e078e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -117,7 +117,7 @@ public class MixedTensor implements Tensor { return index.denseSubspaceSize(); } - + /** * Base class for building mixed tensors. */ @@ -286,7 +286,7 @@ public class MixedTensor implements Tensor { } return typeBuilder.build(); } - + } /** @@ -360,7 +360,7 @@ public class MixedTensor implements Tensor { } return denseSubspaceSize; } - + private TensorAddress sparsePartialAddress(TensorAddress address) { if (type.dimensions().size() != address.size()) { throw new IllegalArgumentException("Tensor type and address are not of same size."); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 2ed211539d8..1b60e01cf7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -52,7 +52,7 @@ import java.util.function.Function; public interface Tensor { // ----------------- Accessors - + TensorType type(); /** Returns whether this have any cells */ @@ -70,13 +70,13 @@ public interface Tensor { /** Returns the values of this in some undefined order */ Iterator valueIterator(); - /** + /** * Returns an immutable map of the cells of this in no particular order. 
- * This may be expensive for some implementations - avoid when possible + * This may be expensive for some implementations - avoid when possible */ Map cells(); - /** + /** * Returns the value of this as a double if it has no dimensions and one value * * @throws IllegalStateException if this does not have zero dimensions and one value @@ -87,9 +87,9 @@ public interface Tensor { if (size() == 0) return Double.NaN; return valueIterator().next(); } - + // ----------------- Primitive tensor functions - + default Tensor map(DoubleUnaryOperator mapper) { return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } @@ -108,7 +108,7 @@ public interface Tensor { } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), + return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); } @@ -123,13 +123,13 @@ public interface Tensor { default Tensor rename(List fromDimensions, List toDimensions) { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - + static Tensor generate(TensorType type, Function, Double> valueSupplier) { return new Generate(type, valueSupplier).evaluate(); } - + // ----------------- Composite tensor functions which have a defined primitive mapping - + default Tensor l1Normalize(String dimension) { return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } @@ -231,7 +231,7 @@ public interface Tensor { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - + Collections.sort(cellEntries, java.util.Map.Entry.comparingByKey()); StringBuilder b = new StringBuilder("{"); @@ -253,7 +253,7 @@ public interface Tensor { */ boolean equals(Object o); - /** + /** * Implement here to make this work across implementations. * Implementations must override equals and call this because this is an interface and cannot override equals. */ @@ -328,13 +328,13 @@ public interface Tensor { @Override public TensorAddress getKey() { return address; } - /** + /** * Returns the direct index which can be used to locate this cell, or -1 if not available. * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ int getDirectIndex() { return -1; } - + @Override public Double getValue() { return value; } @@ -388,20 +388,20 @@ public interface Tensor { /** Returns the type this is building */ TensorType type(); - + /** Return a cell builder */ CellBuilder cell(); /** Add a cell */ Builder cell(TensorAddress address, double value); - + /** Add a cell */ Builder cell(double value, int ... 
labels); - /** - * Add a cell - * - * @param cell a cell providing the location at which to add this cell + /** + * Add a cell + * + * @param cell a cell providing the location at which to add this cell * @param value the value to assign to the cell */ default Builder cell(Cell cell, double value) { @@ -409,12 +409,12 @@ public interface Tensor { } Tensor build(); - + class CellBuilder { private final TensorAddress.Builder addressBuilder; private final Tensor.Builder tensorBuilder; - + CellBuilder(TensorType type, Tensor.Builder tensorBuilder) { addressBuilder = new TensorAddress.Builder(type); this.tensorBuilder = tensorBuilder; @@ -436,5 +436,5 @@ public interface Tensor { } } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 7161450d5d5..ff1202463f2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -32,10 +32,10 @@ public abstract class TensorAddress implements Comparable { /** Returns the number of labels in this */ public abstract int size(); - + /** - * Returns the i'th label in this - * + * Returns the i'th label in this + * * @throws IllegalArgumentException if there is no label at this index */ public abstract String label(int i); @@ -102,23 +102,23 @@ public abstract class TensorAddress implements Comparable { private StringTensorAddress(String ... labels) { this.labels = Arrays.copyOf(labels, labels.length); } - + @Override public int size() { return labels.length; } - + @Override public String label(int i) { return labels[i]; } - + @Override - public int intLabel(int i) { + public int intLabel(int i) { try { return Integer.parseInt(labels[i]); - } + } catch (NumberFormatException e) { throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); } } - + @Override public TensorAddress withLabel(int index, int label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); @@ -169,7 +169,7 @@ public abstract class TensorAddress implements Comparable { private final TensorType type; private final String[] labels; - + public Builder(TensorType type) { this(type, new String[type.dimensions().size()]); } @@ -193,7 +193,7 @@ public abstract class TensorAddress implements Comparable { labels[labelIndex.get()] = label; return this; } - + /** Creates a copy of this which can be modified separately */ public Builder copy() { return new Builder(type, Arrays.copyOf(labels, labels.length)); @@ -202,7 +202,7 @@ public abstract class TensorAddress implements Comparable { public TensorAddress build() { for (int i = 0; i < labels.length; i++) if (labels[i] == null) - throw new IllegalArgumentException("Missing a value for dimension " + + throw new IllegalArgumentException("Missing a value for dimension " + type.dimensions().get(i).name() + " for " + type); return TensorAddress.of(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index da8ab3bb0ec..9b3a9328f07 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -96,7 +96,7 @@ class TensorParser { if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } - + TensorAddress address = addressBuilder.build(); Double value = asDouble(address, s.substring(0, valueEnd).trim()); builder.cell(address, value); diff 
--git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c27ac57415d..914d853aeca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -52,18 +52,18 @@ public class TensorType { public static TensorType fromSpec(String specString) { return TensorTypeParser.fromSpec(specString); } - + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } /** Returns an immutable list of the dimensions of this */ public List dimensions() { return dimensions; } - + /** Returns an immutable set of the names of the dimensions of this */ public Set dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } - + /** Returns the dimension with this name, or empty if not present */ public Optional dimension(String name) { return indexOfDimension(name).map(i -> dimensions.get(i)); @@ -77,7 +77,7 @@ public class TensorType { return Optional.empty(); } - /** + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. */ @@ -131,9 +131,9 @@ public class TensorType { private final String name; - private Dimension(String name) { + private Dimension(String name) { Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = name; } public final String name() { return name; } @@ -149,7 +149,7 @@ public class TensorType { /** Returns true if this is an indexed bound or unboun type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } - /** + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types. This works by degrading to the type making the fewer promises. * [N] + [M] = [min(N, M)] @@ -168,7 +168,7 @@ public class TensorType { IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } - + @Override public abstract String toString(); @@ -178,21 +178,21 @@ public class TensorType { if (other == null || getClass() != other.getClass()) return false; return name.equals(((Dimension)other).name); } - + @Override public int hashCode() { return name.hashCode(); } - + @Override public int compareTo(Dimension other) { return this.name.compareTo(other.name); } - + public static Dimension indexed(String name, int size) { return new IndexedBoundDimension(name, size); } - + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -292,9 +292,9 @@ public class TensorType { public Builder() { } - /** - * Creates a builder containing a combination of the dimensions of the given types - * + /** + * Creates a builder containing a combination of the dimensions of the given types + * * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. 
@@ -328,9 +328,12 @@ public class TensorType { } } - /** + /** Returns the current number of dimensions in this */ + public int rank() { return dimensions.size(); } + + /** * Adds a new dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ private Builder add(Dimension dimension) { @@ -349,7 +352,7 @@ public class TensorType { return this; } - /** + /** * Adds a bound indexed dimension to this * * @throws IllegalArgumentException if the dimension is already present @@ -358,7 +361,7 @@ public class TensorType { /** * Adds an unbound indexed dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ public Builder indexed(String name) { @@ -378,7 +381,7 @@ public class TensorType { public Builder dimension(Dimension dimension) { return add(dimension); } - + /** Returns the given dimension, or empty if none is present */ public Optional getDimension(String dimension) { return Optional.ofNullable(dimensions.get(dimension)); @@ -396,7 +399,7 @@ public class TensorType { public TensorType build() { return new TensorType(dimensions.values()); } - + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 84caca78fb2..3db661f8a23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -2,16 +2,17 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; - -import java.util.HashMap; +import com.yahoo.tensor.Tensor; /** * An evaluation context which is passed down to all nested functions during evaluation. - * The default context is empty to allow various evaluation frameworks to support their own implementation. 
- * + * * @author bratseth */ @Beta public interface EvaluationContext { + /** Returns the tensor bound to this name, or null if none */ + Tensor getTensor(String name); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index cf704c15f4f..db8a66a5fa2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } - /** Returns the tensor bound to this name, or null if none */ - public Tensor get(String name) { return bindings.get(name); } + @Override + public Tensor getTensor(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 8ade181bdb7..1f6ad050368 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -12,18 +12,18 @@ import java.util.List; /** * A tensor variable name which resolves to a tensor in the context at evaluation time - * + * * @author bratseth */ @Beta public class VariableTensor extends PrimitiveTensorFunction { private final String name; - + public VariableTensor(String name) { this.name = name; } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction { @Override public Tensor evaluate(EvaluationContext context) { - return ((MapEvaluationContext)context).get(name); + return context.getTensor(name); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 8f4dbf014a7..191c7988443 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. 
- * + * * @author bratseth */ @Beta diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 1dbb94fdb20..faa0ca36cb6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -15,7 +15,7 @@ import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension - * + * * @author bratseth */ @Beta @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } - + private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); @@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction { Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); return tensor.multiply(unitTensor); } - + } /** Returns the type resulting from concatenating a and b */ @@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Combine two addresses, adding the offset to the concat dimension * - * @return the combined address or null if the addresses are incompatible + * @return the combined address or null if the addresses are incompatible * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, @@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 4ac7b21ba90..14ed38718ce 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -10,18 +10,18 @@ import java.util.List; /** * A function which returns a constant tensor. - * + * * @author bratseth */ @Beta public class ConstantTensor extends PrimitiveTensorFunction { private final Tensor constant; - + public ConstantTensor(String tensorString) { this.constant = Tensor.from(tensorString); } - + public ConstantTensor(Tensor tensor) { this.constant = tensor; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index bbdbd5c3df1..c75d8ee4753 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -11,19 +11,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. 
- * + * * @author bratseth */ public class Diag extends CompositeTensorFunction { private final TensorType type; private final Function, Double> diagFunction; - + public Diag(TensorType type) { this.type = type; this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList())); } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction { public String toString(ToStringContext context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 6ea73b7f310..e42d25197e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -15,7 +15,7 @@ import java.util.function.Function; /** * An indexed tensor whose values are generated by a function - * + * * @author bratseth */ @Beta @@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction { /** * Creates a generated tensor - * + * * @param type the type of the tensor * @param generator the function generating values from a list of ints specifying the indexes of the * tensor cell which will receive the value @@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction { this.type = type; this.generator = generator; } - + private void validateType(TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) if (dimension.type() != TensorType.Dimension.Type.indexedBound) @@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { return this; } - + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); @@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction { } return builder.build(); } - + private DimensionSizes dimensionSizes(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < b.dimensions(); i++) b.set(i, type.dimensions().get(i).size().get()); return b.build(); } - + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 9a37127e1f0..ff887e3e9a6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator; * The join tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. 
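The cross-product behaviour described above is easiest to see with two one-dimensional tensors over different dimensions. A minimal sketch, assuming the constructors used elsewhere in this patch (Join taking two TensorFunctions and a DoubleBinaryOperator, ConstantTensor); an empty MapEvaluationContext is passed only because evaluate requires a context.

    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.evaluation.MapEvaluationContext;
    import com.yahoo.tensor.functions.ConstantTensor;
    import com.yahoo.tensor.functions.Join;

    public class JoinSketch {

        public static void main(String[] args) {
            // tensor(x[]) with values 1,2 and tensor(y[]) with values 10,20
            Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1.0, {x:1}:2.0 }");
            Tensor b = Tensor.from("tensor(y[]):{ {y:0}:10.0, {y:1}:20.0 }");

            // Join over the cross product of cells, combining values by multiplication:
            // the result has both dimensions, tensor(x[],y[]), with 4 cells
            Tensor product = new Join(new ConstantTensor(a), new ConstantTensor(b), (x, y) -> x * y)
                                     .evaluate(new MapEvaluationContext());
            System.out.println(product);
        }

    }
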
- * + * * @author bratseth */ @Beta public class Join extends PrimitiveTensorFunction { - + private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; @@ -56,7 +56,7 @@ public class Join extends PrimitiveTensorFunction { if (aDim.name().equals(bDim.name())) { // include if (aDim.isIndexed() && bDim.isIndexed()) { if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), bDim.size().orElse(Integer.MAX_VALUE))); else typeBuilder.indexed(aDim.name()); @@ -112,11 +112,11 @@ public class Join extends PrimitiveTensorFunction { else return generalJoin(a, b, joinedType); } - + private boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator aIterator = a.valueIterator(); @@ -138,7 +138,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) @@ -150,7 +150,7 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -158,14 +158,14 @@ public class Join extends PrimitiveTensorFunction { // Find dimensions which are only in the supertype Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); - + for (Iterator i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } - + return builder.build(); } @@ -224,7 +224,7 @@ public class Join extends PrimitiveTensorFunction { subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) @@ -259,7 +259,7 @@ public class Join extends PrimitiveTensorFunction { DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a - for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each 
combination of dimensions in a which is also in b while (aSubspace.hasNext()) { @@ -276,7 +276,7 @@ public class Join extends PrimitiveTensorFunction { } } } - + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) @@ -284,7 +284,7 @@ public class Join extends PrimitiveTensorFunction { builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); return builder.build(); } - + /** Returns the sizes from the joined sizes which are present in the type argument */ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); @@ -295,7 +295,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); @@ -364,7 +364,7 @@ public class Join extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -384,7 +384,7 @@ public class Join extends PrimitiveTensorFunction { return TensorAddress.of(joinedLabels); } - /** + /** * Maps the content in the given list to the given array, using the given index map. 
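The private index-mapping helpers documented around here are easy to misread. The standalone sketch below restates the same idea with the public TensorType API (dimensions(), indexOfDimension()); it is not the method from the patch, only an illustration of the mapping it describes, and the helper and class names are hypothetical.

    import com.yahoo.tensor.TensorType;

    public class DimensionIndexMappingSketch {

        /**
         * Returns an array with one entry per dimension of fromType, holding the index
         * of the same dimension name in toType, or -1 if toType does not have it.
         */
        static int[] mapIndexes(TensorType fromType, TensorType toType) {
            int[] toIndexes = new int[fromType.dimensions().size()];
            for (int i = 0; i < fromType.dimensions().size(); i++)
                toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
            return toIndexes;
        }

        public static void main(String[] args) {
            TensorType from = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build();
            TensorType to = new TensorType.Builder().indexed("y", 3).indexed("z", 4).build();
            // x is not present in 'to' (-1); y is dimension 0 of 'to'
            System.out.println(java.util.Arrays.toString(mapIndexes(from, to)));
        }

    }
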
* * @return true if the mapping was successful, false if one of the destination positions was diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index d322a6ab497..a5e1a016a41 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -32,7 +32,7 @@ public class Map extends PrimitiveTensorFunction { this.argument = argument; this.mapper = mapper; } - + public static TensorType outputType(TensorType inputType) { return inputType; } public TensorFunction argument() { return argument; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 5e102454487..4071917c2b5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -15,15 +15,15 @@ public class Matmul extends CompositeTensorFunction { private final TensorFunction argument1, argument2; private final String dimension; - + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; } - + public static TensorType outputType(TensorType a, TensorType b, String dimension) { - return Reduce.outputType(Join.outputType(a, b), ImmutableList.of(dimension)); + return Join.outputType(a, b); } @Override @@ -44,7 +44,7 @@ public class Matmul extends CompositeTensorFunction { Reduce.Aggregator.sum, dimension); } - + @Override public String toString(ToStringContext context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index efb7b9e500c..b7c9a5d2342 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor; * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. * Primitive tensor functions are fully inspectable. 
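The Matmul change above is easiest to follow by inspecting how the composite expands into the primitives named in its toPrimitive method. A minimal sketch, assuming the Matmul(TensorFunction, TensorFunction, String) constructor shown in the hunk; the tensor literals, expected value, and class name are illustrative only.

    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.evaluation.MapEvaluationContext;
    import com.yahoo.tensor.functions.ConstantTensor;
    import com.yahoo.tensor.functions.Matmul;

    public class MatmulSketch {

        public static void main(String[] args) {
            // A 1x2 and a 2x1 matrix sharing the contraction dimension y
            Tensor a = Tensor.from("tensor(x[1],y[2]):{ {x:0,y:0}:1.0, {x:0,y:1}:2.0 }");
            Tensor b = Tensor.from("tensor(y[2],z[1]):{ {y:0,z:0}:3.0, {y:1,z:0}:4.0 }");

            Matmul matmul = new Matmul(new ConstantTensor(a), new ConstantTensor(b), "y");

            // The composite is inspectable: it expands to a reduce(join(..., product), sum, y)
            System.out.println(matmul.toPrimitive());

            // Evaluating the primitive form sums over y: a single cell, 1*3 + 2*4 = 11
            System.out.println(matmul.toPrimitive().evaluate(new MapEvaluationContext()));
        }

    }
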
- * + * * @author bratseth */ @Beta public abstract class PrimitiveTensorFunction extends TensorFunction { - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 457763e97ba..958ef85d1dc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -22,11 +22,11 @@ import java.util.stream.Stream; public class Random extends CompositeTensorFunction { private final TensorType type; - + public Random(TensorType type) { this.type = type; } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction { public String toString(ToStringContext context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index e2b39a2048d..a56f82b026a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -12,19 +12,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor * indexes of each position. - * + * * @author bratseth */ public class Range extends CompositeTensorFunction { private final TensorType type; private final Function, Double> rangeFunction; - + public Range(TensorType type) { this.type = type; this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList())); } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction { public String toString(ToStringContext context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index a51df12e522..de9f90a5804 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions + * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. 
If an empty list is given, all dimensions are reduced, @@ -69,7 +69,7 @@ public class Reduce extends PrimitiveTensorFunction { } return b.build(); } - + public TensorFunction argument() { return argument; } @Override @@ -91,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -103,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -112,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map aggregatingCells = new HashMap<>(); for (Iterator i = argument.cellIterator(); i.hasNext(); ) { @@ -131,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -147,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator i = argument.valueIterator(); i.hasNext(); ) @@ -163,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -174,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -197,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double 
aggregatedValue() { return valueSum / valueCount; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6e52760424e..ec9b762a41c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,8 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -19,7 +17,7 @@ import java.util.Objects; /** * The rename tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -28,7 +26,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List fromDimensions; private final List toDimensions; - + public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); } @@ -46,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List functionArguments() { return Collections.singletonList(argument); } @@ -66,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction { Map fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -74,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry cell = i.next(); @@ -90,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -99,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map fromToMap() { Map map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index cabcce198d1..533a46f87fe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ 
b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -12,7 +12,7 @@ import java.util.List; * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. * All tensor functions are immutable. - * + * * @author bratseth */ @Beta @@ -48,11 +48,11 @@ public abstract class TensorFunction { /** * Return a string representation of this context. - * + * * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); - + @Override public String toString() { return toString(ToStringContext.empty()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index e8c425d49e0..416b28afa22 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -24,7 +24,7 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. - * + * * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 8b7325ec211..aabb53d1c67 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -16,9 +16,9 @@ import java.util.Optional; * * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* * Cell_values = [double, double, double, ...]* - * where values are encoded in order of increasing indexes in each dimension, increasing + * where values are encoded in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. - * + * * @author bratseth */ @Beta @@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat { type = optionalType.get(); TensorType serializedType = decodeType(buffer); if ( ! 
serializedType.isAssignableTo(type)) - throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 7467554790a..01a1d023f2b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -46,9 +46,9 @@ public class TypedBinaryFormat { return result; } - /** - * Decode some data to a tensor - * + /** + * Decode some data to a tensor + * * @param type the type to decode and validate to, or empty to use the type given in the data * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array * @return the resulting tensor diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index d199dd3a876..abdb3071bf7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -13,14 +13,14 @@ import java.util.stream.Collectors; /** * Microbenchmark of tensor operations. - * + * * @author bratseth */ public class TensorFunctionBenchmark { private final static Random random = new Random(); - - public double benchmark(int iterations, List modelVectors, TensorType.Dimension.Type dimensionType, + + public double benchmark(int iterations, List modelVectors, TensorType.Dimension.Type dimensionType, boolean extraSpace) { Tensor queryVector = vectors(1, 300, dimensionType).get(0); if (extraSpace) { @@ -34,7 +34,7 @@ public class TensorFunctionBenchmark { long totalTime = System.currentTimeMillis() - startTime; return (double)totalTime / (double)iterations; } - + private Tensor unitVector(String dimension) { return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build()) .cell().label(dimension, 0).value(1).build(); @@ -49,11 +49,11 @@ public class TensorFunctionBenchmark { private double dotProduct(Tensor tensor, List tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), + TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), + new VariableTensor("argument"), (a, b) -> a * b), Reduce.Aggregator.sum).toPrimitive(); MapEvaluationContext context = new MapEvaluationContext(); - + for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); double dotProduct = dotProductFunction.evaluate(context).asDouble(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 30078b4a826..693b0f09351 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -25,7 +25,7 @@ import static org.junit.Assert.fail; /** * Tests tensor functionality - * + * * @author bratseth */ public class TensorTestCase { @@ -108,7 +108,7 @@ public class TensorTestCase { Tensor.diag(new 
TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } - + /** Test the same computation made in various ways which are implemented with special-case optimizations */ @Test public void testOptimizedComputation() { @@ -130,7 +130,7 @@ public class TensorTestCase { assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); - + // Test the unoptimized path by joining in another dimension Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build(); Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build(); @@ -138,7 +138,7 @@ public class TensorTestCase { Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK); assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace))); } - + private double dotProduct(Tensor tensor, List tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), @@ -161,7 +161,7 @@ public class TensorTestCase { private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) { return vectors(vectorSize, dimensionType, 1).get(0); } - + /** Create a list of vectors having a single dimension x */ private List vectors(TensorType.Dimension.Type dimensionType, int vectorCount) { return vectors(3, dimensionType, vectorCount); @@ -179,8 +179,8 @@ public class TensorTestCase { } return tensors; } - - /** + + /** * Create a matrix of vectors (in dimension i) where each vector has the dimension x. * This matrix contains the same vectors as returned by createVectors, in a single list element for convenience. 
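The dotProduct helper exercised in these tests combines the pieces above: a Join that multiplies matching cells and a Reduce that sums the result away. A minimal sketch of the same pattern, assuming the constructors used in the test and benchmark code in this patch; the names and values are illustrative.

    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.evaluation.MapEvaluationContext;
    import com.yahoo.tensor.evaluation.VariableTensor;
    import com.yahoo.tensor.functions.ConstantTensor;
    import com.yahoo.tensor.functions.Join;
    import com.yahoo.tensor.functions.Reduce;
    import com.yahoo.tensor.functions.TensorFunction;

    public class DotProductSketch {

        public static void main(String[] args) {
            Tensor query = Tensor.from("tensor(x[]):{ {x:0}:1.0, {x:1}:2.0, {x:2}:3.0 }");
            Tensor document = Tensor.from("tensor(x[]):{ {x:0}:4.0, {x:1}:5.0, {x:2}:6.0 }");

            // join multiplies matching cells, reduce sums all dimensions away:
            // 1*4 + 2*5 + 3*6 = 32
            TensorFunction dotProduct =
                    new Reduce(new Join(new ConstantTensor(query), new VariableTensor("document"), (a, b) -> a * b),
                               Reduce.Aggregator.sum)
                            .toPrimitive();

            MapEvaluationContext context = new MapEvaluationContext();
            context.put("document", document);
            System.out.println(dotProduct.evaluate(context).asDouble()); // 32.0
        }

    }
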
*/ diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java index fab53218b2c..f11c068bd74 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java @@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals; * @author bratseth */ public class JoinTestCase { - + /** Test the indexed subspace join optimization */ @Test public void testJoinIndexedSubspace() { Tensor t1, t2; - + t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}"); t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}"); assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"), @@ -34,10 +34,10 @@ public class JoinTestCase { assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"), t2.divide(t1)); } - + @Test public void testGeneralJoin() { - assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), + assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }") .divide(Tensor.from("tensor(y[]):{{y:0}:2}"))); @@ -45,5 +45,5 @@ public class JoinTestCase { Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }") .divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }"))); } - + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index 8a58cb0bbed..55069eaced7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals; /** * Tests translation of composite to primitive tensor function translation. 
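The translation test that follows checks the textual form produced by toPrimitive(). Below is a minimal sketch of the same check outside JUnit, assuming L1Normalize and ConstantTensor as used in that test; the expected output in the comment is taken directly from it.

    import com.yahoo.tensor.functions.ConstantTensor;
    import com.yahoo.tensor.functions.L1Normalize;

    public class TranslationSketch {

        public static void main(String[] args) {
            // l1_normalize(t, x) is a composite: it becomes a join of t with reduce(t, sum, x)
            L1Normalize normalize = new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x");
            System.out.println(normalize.toPrimitive());
            // prints: join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))
        }

    }
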
- * + * * @author bratseth */ public class TensorFunctionTestCase { @@ -16,12 +16,12 @@ public class TensorFunctionTestCase { public void testTranslation() { assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))", new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); - assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", + assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } - + private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { assertEquals(expectedTranslation, inputFunction.toPrimitive().toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 349309a5052..15a872e439f 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase { assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); } - + @Test public void testSerializationToSeparateType() { assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])")); @@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index b1d7d797b3e..33dfca017f4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java index 68bf59e3ed9..f002637847b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java @@ -50,7 +50,7 @@ public class SerializationTestCase { JsonNode node = mapper.readTree(test); if (node.has("tensor") && node.has("binary")) { System.out.println("Running test: " + test); - + Tensor tensor = buildTensor(node.get("tensor")); String spec = 
getSpec(node.get("tensor")); byte[] encodedTensor = TypedBinaryFormat.encode(tensor); @@ -123,7 +123,7 @@ public class SerializationTestCase { private byte[] getBytes(String binaryRepresentation) { return parseHexValue(binaryRepresentation.substring(2)); } - + private byte[] parseHexValue(String s) { final int len = s.length(); byte[] bytes = new byte[len/2]; diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index d17148cf8dc..f895b64379b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase { private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } -- cgit v1.2.3
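The serialization tests above all follow the same encode/decode round trip. Below is a minimal sketch of that round trip, assuming TypedBinaryFormat and GrowableByteBuffer (taken here to live in com.yahoo.io) as used in the tests; the tensor literal, expected type, and class name are illustrative only.

    import com.yahoo.io.GrowableByteBuffer;
    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.TensorType;
    import com.yahoo.tensor.serialization.TypedBinaryFormat;

    import java.util.Optional;

    public class SerializationRoundTripSketch {

        public static void main(String[] args) {
            Tensor tensor = Tensor.from(
                    "tensor(x[2],y[2]):{ {x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0 }");

            // Encode to the typed binary format, then decode while validating against an expected type
            byte[] encoded = TypedBinaryFormat.encode(tensor);
            Tensor decoded = TypedBinaryFormat.decode(Optional.of(TensorType.fromSpec("tensor(x[],y[])")),
                                                      GrowableByteBuffer.wrap(encoded));

            // As in the tests, the decoded tensor compares equal to the original
            System.out.println(tensor.equals(decoded));
        }

    }
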