summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java20
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--container-dev/pom.xml24
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java12
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java8
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java2
-rw-r--r--container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java2
-rw-r--r--container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java29
-rw-r--r--container/pom.xml12
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java13
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java6
-rw-r--r--document/src/main/java/com/yahoo/document/DataType.java8
-rw-r--r--document/src/main/java/com/yahoo/document/DocumentTypeManager.java12
-rw-r--r--document/src/main/java/com/yahoo/document/TensorDataType.java6
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java4
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java8
-rw-r--r--document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java2
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java2
-rw-r--r--pom.xml935
-rw-r--r--searchlib/pom.xml15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java44
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java51
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java160
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java94
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java147
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java36
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py89
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt5039
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001bin0 -> 31400 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.indexbin0 -> 159 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java114
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java4
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java164
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java54
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java97
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java97
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java2
91 files changed, 7355 insertions, 470 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
index cee501841b4..61cab2f6ce7 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
@@ -4,10 +4,9 @@ package com.yahoo.config.application.api;
import java.util.logging.Level;
/**
- * Used during application deployment to persist and propagate messages to end user
+ * Used during application deployment to propagate messages to the end user
*
- * @author lulf
- * @since 5.1
+ * @author Ulf Lillengen
*/
public interface DeployLogger {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
index 69e353ceb35..65176006a2a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
@@ -118,7 +118,7 @@ public class TensorTransformer extends ExpressionTransformer {
private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
-
+
TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
index c52a5dc465d..f932265cb93 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
@@ -259,7 +259,7 @@ public final class Attribute implements Cloneable, Serializable {
throw new IllegalArgumentException("Field " + fieldType + " not supported in convertCollectionType");
}
}
-
+
private static Optional<TensorType> convertTensorType(DataType fieldType) {
if ( ! ( fieldType instanceof TensorDataType)) return Optional.empty();
return Optional.of(((TensorDataType)fieldType).getTensorType());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
index c8918f39834..8b6df1a87db 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
@@ -29,7 +29,7 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public <T extends Expression> boolean containsExpression(Class<T> searchFor) {
- throw createUnsupportedException();
+ throw createUnsupportedException(searchFor.getSimpleName());
}
@Override
@@ -79,9 +79,9 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public Index getIndex(String name) {
- if (!importedField.fieldName().equals(name)) {
+ if ( ! importedField.fieldName().equals(name)) {
throw new IllegalArgumentException("Getting an index (" + name + ") with different name than the imported field ("
- + importedField.fieldName() + ") is not supported");
+ + importedField.fieldName() + ") is not supported");
}
String targetIndexName = importedField.targetField().getName();
return importedField.targetField().getIndex(targetIndexName);
@@ -104,7 +104,7 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public ScriptExpression getIndexingScript() {
- throw createUnsupportedException();
+ throw createUnsupportedException("indexing");
}
@Override
@@ -119,12 +119,12 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public ImmutableSDField getStructField(String name) {
- throw createUnsupportedException();
+ throw createUnsupportedException("struct");
}
@Override
public Collection<? extends ImmutableSDField> getStructFields() {
- throw createUnsupportedException();
+ throw createUnsupportedException("struct");
}
@Override
@@ -134,12 +134,12 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public Stemming getStemming(Search search) {
- throw createUnsupportedException();
+ throw createUnsupportedException("stemming");
}
@Override
public Ranking getRanking() {
- throw createUnsupportedException();
+ throw createUnsupportedException("ranking");
}
@Override
@@ -158,8 +158,8 @@ public class ImmutableImportedSDField implements ImmutableSDField {
importedField.targetField().getDataType());
}
- private static UnsupportedOperationException createUnsupportedException() {
- return new UnsupportedOperationException("This aspect is not meaningful or relevant for an imported field.");
+ private static UnsupportedOperationException createUnsupportedException(String aspect) {
+ return new UnsupportedOperationException("'" + aspect + "' is not meaningful or relevant for an imported field.");
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
index 96a9448739a..9368d6aaa39 100644
--- a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
+++ b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
@@ -20,7 +20,7 @@ import java.util.Set;
*/
public class DocumentManager {
- public DocumentmanagerConfig.Builder produce(DocumentModel model,
+ public DocumentmanagerConfig.Builder produce(DocumentModel model,
DocumentmanagerConfig.Builder documentConfigBuilder) {
documentConfigBuilder.enablecompression(false);
Set<DataType> handled = new HashSet<>();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index 6eeb12ffdd9..ce3c04f41f7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -45,7 +45,7 @@ public class ConstantTensorJsonValidator {
throw new IllegalArgumentException("Ranking constant file names must end with either '.json' or '.json.lz4'");
}
}
-
+
private void validateTensor(TensorType type, Reader tensorData) {
wrapIOException(() -> {
this.parser = jsonFactory.createParser(tensorData);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
index 4a9310799aa..c686f023d5b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
@@ -64,7 +64,7 @@ public class RankingConstantsValidator extends Validator {
private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage applicationPackage) throws FileNotFoundException {
ApplicationFile tensorApplicationFile = applicationPackage.getFile(Path.fromString(rankingConstant.getFileName()));
- new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(),
+ new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(),
rankingConstant.getTensorType(),
tensorApplicationFile.createReader());
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
index a5b7d67e377..e1675007bbc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
@@ -8,6 +8,7 @@ import com.yahoo.vespa.model.content.Redundancy;
* Builds redundancy config for a content cluster.
*/
public class RedundancyBuilder {
+
Redundancy build(ModelElement clusterXml) {
Integer initialRedundancy = 2;
Integer finalRedundancy = 3;
@@ -37,4 +38,5 @@ public class RedundancyBuilder {
return new Redundancy(initialRedundancy, finalRedundancy, readyCopies);
}
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 9407c21fee8..960a3b7d6db 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -173,7 +173,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent());
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.numeric").isPresent());
}
-
+
private static Optional<String> findProperty(List<Pair<String, String>> properties, String key) {
for (Pair<String, String> property : properties)
if (property.getFirst().equals(key))
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
index 7cd00e155bb..4600f6ae4c6 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
@@ -123,7 +123,7 @@ public class ExportingTestCase extends AbstractExportingTestCase {
public void testIndexinfoFieldsets() throws IOException, ParseException {
assertCorrectDeriving("indexinfo_fieldsets");
}
-
+
@Test
public void testStreamingJuniper() throws IOException, ParseException {
assertCorrectDeriving("streamingjuniper");
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 12bdd8d2b5c..e5693d24f0f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -202,5 +202,5 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
}
return b.toString();
}
-
+
}
diff --git a/container-dev/pom.xml b/container-dev/pom.xml
index f62bbd22690..16006452e61 100644
--- a/container-dev/pom.xml
+++ b/container-dev/pom.xml
@@ -121,6 +121,18 @@
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -189,6 +201,18 @@
<groupId>xerces</groupId>
<artifactId>xercesImpl</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
index fc1bbace092..1e44a8fa64d 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
@@ -1,16 +1,16 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.prelude.fastsearch;
+import com.yahoo.container.search.LegacyEmulationConfig;
+import com.yahoo.data.access.Inspector;
+import com.yahoo.log.LogLevel;
+
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
-import com.yahoo.data.access.Inspector;
-import com.yahoo.container.search.LegacyEmulationConfig;
-
-import com.yahoo.log.LogLevel;
/**
* @author Bjørn Borud
@@ -25,7 +25,7 @@ public abstract class DocsumField {
Map<String, Constructor<? extends DocsumField>> constructors = new HashMap<>();
- void put(String typename, Class<? extends DocsumField> fieldClass)
+ void put(String typename, Class<? extends DocsumField> fieldClass)
throws NoSuchMethodException, SecurityException {
Constructor<? extends DocsumField> constructor = fieldClass.getConstructor(String.class);
constructors.put(typename, constructor);
@@ -106,7 +106,7 @@ public abstract class DocsumField {
public abstract Object decode(ByteBuffer b);
/**
- * Get the number of bytes this field occupies in the given buffer
+ * Get the number of bytes this field occupies in the given buffer
* AND SET(!) the position to the first byte after this field.
*/
public abstract int getLength(ByteBuffer b);
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
index 1524a4da426..692e93bed7e 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
@@ -109,7 +109,7 @@ public class FastHit extends Hit {
/**
* Returns the explicitly set uri if available, returns "index:[source]/[partid]/[id]" otherwise
- *
+ *
* @return uri of hit
*/
public URI getUri() {
@@ -128,9 +128,9 @@ public class FastHit extends Hit {
}
/**
- * The uri of the index location of this hit ("index:[source]/[partid]/[id]").
+ * The uri of the index location of this hit ("index:[source]/[partid]/[id]").
* This is the uri if no other uri is assigned
- *
+ *
* @return uri to the index.
*/
public URI getIndexUri() {
@@ -215,7 +215,7 @@ public class FastHit extends Hit {
* The empty string ("") if no value is assigned in the document.
*
* <li><b>Dynamic summary string fields</b>: A Java String before JuniperSearcher and a HitField after.</li>
- *
+ *
* <li><b>Numerics</b>: The corresponding numeric Java type.<br>
* If the field has <i>no value</i> assigned in the document,
* the special numeric {@link com.yahoo.search.result.NanNumber#NaN} is returned.
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
index e0ca7fbe6e1..d8b38667224 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
@@ -13,7 +13,7 @@ import java.util.Optional;
/**
* A tensor field. Tensors are encoded as a data field where the data (following the length)
* is encoded in a tensor binary format defined by com.yahoo.tensor.serialization.TypedBinaryFormat
- *
+ *
* @author bratseth
*/
public class TensorField extends DocsumField implements VariableLengthField {
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index 0ec15b95b0d..0fd529bf262 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -16,7 +16,7 @@ import java.util.Optional;
public class TensorFieldType extends FieldType {
// TODO: Require tensor type
-
+
private final Optional<TensorType> type;
/** Creates a tensor field type with optional information about the kind of tensor this will hold */
diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
index 15a0fd60511..5494d1965f8 100644
--- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
+++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
@@ -102,7 +102,7 @@ public class SlimeSummaryTestCase {
public void testDecoding() {
Tensor tensor1 = Tensor.from("tensor(x{},y{}):{{x:foo,y:bar}:0.1}");
Tensor tensor2 = Tensor.from("tensor(x[],y[1]):{{x:0,y:0}:-0.3}");
-
+
String summary_cf = "file:src/test/java/com/yahoo/prelude/fastsearch/summary.cfg";
DocsumDefinitionSet set = createDocsumDefinitionSet(summary_cf);
byte[] docsum = makeDocsum(tensor1, tensor2);
diff --git a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
index e59c03b33c3..62eacaa0afe 100644
--- a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
@@ -2,8 +2,6 @@
package com.yahoo.search.test;
import com.yahoo.component.chain.Chain;
-import com.yahoo.language.Language;
-import com.yahoo.language.Linguistics;
import com.yahoo.language.detect.Detection;
import com.yahoo.language.detect.Detector;
import com.yahoo.language.detect.Hint;
@@ -28,7 +26,6 @@ import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
-
import com.yahoo.yolean.Exceptions;
import org.junit.Ignore;
import org.junit.Test;
@@ -45,14 +42,14 @@ import java.util.stream.Collectors;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
@@ -69,7 +66,7 @@ public class QueryTestCase {
assertEquals("", q.properties().get("aParameter"));
assertNull(q.properties().get("notSetParameter"));
}
-
+
// TODO: YQL work in progress (jon)
@Ignore
@Test
@@ -693,7 +690,7 @@ public class QueryTestCase {
List<IndexedItem> l = QueryTree.getPositiveTerms(i);
assertEquals(3, l.size());
}
-
+
@Test
public void testHeuristicLanguageDetectionTextExtraction() {
assertDetectionText("b ", "a:b", "text:a", "text:default");
@@ -720,27 +717,27 @@ public class QueryTestCase {
q.getModel().getQueryTree(); // cause parsing
assertEquals(expectedDetectionText, mockLinguistics.detector.lastDetectionText);
}
-
+
/** A linguistics instance which records the last language detection text passed to it */
private static class MockLinguistics extends SimpleLinguistics {
final MockDetector detector = new MockDetector();
-
+
@Override
public Detector getDetector() { return detector; }
-
+
}
-
+
private static class MockDetector extends SimpleDetector {
String lastDetectionText = null;
-
+
@Override
public Detection detect(String input, Hint hint) {
lastDetectionText = input;
return super.detect(input, hint);
}
-
+
}
protected boolean contains(String lineSubstring,String[] lines) {
diff --git a/container/pom.xml b/container/pom.xml
index 3793a3508a4..d252a5eee4a 100644
--- a/container/pom.xml
+++ b/container/pom.xml
@@ -47,6 +47,18 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
index 1b2ad9f938a..fb675862320 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
@@ -24,6 +24,7 @@ import java.util.Objects;
* @author smorgrav
*/
public class ClusterCost {
+
private final double tco;
private final double waste;
private final ClusterInfo clusterInfo;
@@ -32,8 +33,8 @@ public class ClusterCost {
private final ClusterUtilization resultUtilization;
/**
- * @param clusterInfo Value object with cluster info e.g. the TCO for the hardware used
- * @param systemUtilization Utilization of system resources (as ratios)
+ * @param clusterInfo value object with cluster info e.g. the TCO for the hardware used
+ * @param systemUtilization utilization of system resources (as ratios)
*/
public ClusterCost(ClusterInfo clusterInfo,
ClusterUtilization systemUtilization) {
@@ -79,10 +80,10 @@ public class ClusterCost {
}
static ClusterUtilization calculateResultUtilization(ClusterUtilization system, ClusterUtilization target) {
- double cpu = ratio(system.getCpu(),target.getCpu());
- double mem = ratio(system.getMemory(),target.getMemory());
- double disk = ratio(system.getDisk(),target.getDisk());
- double diskbusy = ratio(system.getDiskBusy(),target.getDiskBusy());
+ double cpu = ratio(system.getCpu(), target.getCpu());
+ double mem = ratio(system.getMemory(), target.getMemory());
+ double disk = ratio(system.getDisk(), target.getDisk());
+ double diskbusy = ratio(system.getDiskBusy(), target.getDiskBusy());
return new ClusterUtilization(mem, cpu, disk, diskbusy);
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
index 585690793bb..371e1c41e32 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
@@ -44,17 +44,17 @@ public class DeploymentCost {
return clusters;
}
- /** @return Total cost of ownership for the deployment (sum of all clusters) */
+ /** Returns the total monthly cost of ownership for the deployment (sum of all clusters) */
public double getTco() {
return tco;
}
- /** @return The utilization of clusters that wastes most money in this deployment */
+ /** Returns the utilization of clusters that wastes most money in this deployment */
public double getUtilization() {
return utilization;
}
- /** @return The amount of dollars spent and not utilized */
+ /** Returns the amount of dollars spent and not utilized */
public double getWaste() {
return waste;
}
diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java
index c8a04866aa9..abdbf394591 100644
--- a/document/src/main/java/com/yahoo/document/DataType.java
+++ b/document/src/main/java/com/yahoo/document/DataType.java
@@ -51,7 +51,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory());
public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory());
public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory());
- public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately
+ public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately
// ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor
// Tags are converted to weightedset<string> when reading the search definition TODO: Remove it
@@ -99,7 +99,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
/**
* Creates a field value by reflection
- *
+ *
* @param arg the value of the newly created field value
* @return a fully constructed value
*/
@@ -201,7 +201,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
public static TensorDataType getTensor(TensorType type) {
return new TensorDataType(type);
}
-
+
public String getName() {
return name;
}
@@ -267,7 +267,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
*/
public FieldPath buildFieldPath(String fieldPathString) {
if (fieldPathString.length() > 0) {
- throw new IllegalArgumentException("Datatype " + toString() +
+ throw new IllegalArgumentException("Datatype " + toString() +
" does not support further recursive structure: " + fieldPathString);
}
return new FieldPath();
diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
index 8c9318199d8..5fad35a2287 100644
--- a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
+++ b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
@@ -38,7 +38,7 @@ public class DocumentTypeManager {
// *Configured data types* (not built-in/primitive) indexed by their id
//
// *Primitive* data types are always available and have a single id.
- //
+ //
// *Built-in dynamic* types: The tensor type.
// Any tensor type has the same id and is always available just like primitive types.
// However, unlike primitive types, each tensor type is a separate DataType instance
@@ -112,7 +112,7 @@ public class DocumentTypeManager {
public DataType getDataType(String name) {
if (name.startsWith("tensor(")) // built-in dynamic
return new TensorDataType(TensorType.fromSpec(name));
-
+
List<DataType> foundTypes = new ArrayList<>();
for (DataType type : dataTypes.values()) {
if (type.getName().equalsIgnoreCase(name)) {
@@ -141,10 +141,10 @@ public class DocumentTypeManager {
}
public DataType getDataType(int code) { return getDataType(code, ""); }
-
+
/**
* Return a data type instance
- *
+ *
* @param code the code of the data type to return, which must be either built in or present in this manager
* @param detailedType detailed type information, or the empty string if none
* @return the appropriate DataType instance
@@ -183,7 +183,7 @@ public class DocumentTypeManager {
/**
* Register a single datatype. Re-registering an existing, but equal, datatype is ok.
- *
+ *
* @param type The datatype to register
*/
void registerSingleType(DataType type) {
@@ -280,7 +280,7 @@ public class DocumentTypeManager {
/**
* Returns a read only view of the registered data types
- *
+ *
* @return collection of types
*/
public Collection<DataType> getDataTypes() {
diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java
index aefdc030a12..50e9cf0f60f 100644
--- a/document/src/main/java/com/yahoo/document/TensorDataType.java
+++ b/document/src/main/java/com/yahoo/document/TensorDataType.java
@@ -8,13 +8,13 @@ import com.yahoo.vespa.objects.Ids;
/**
* A DataType containing a tensor type
- *
+ *
* @author bratseth
*/
public class TensorDataType extends DataType {
private final TensorType tensorType;
-
+
// The global class identifier shared with C++.
public static int classId = registerClass(Ids.document + 59, TensorDataType.class);
@@ -47,5 +47,5 @@ public class TensorDataType extends DataType {
/** Returns the type of the tensor this field can hold */
public TensorType getTensorType() { return tensorType; }
-
+
}
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index ae8d5cf596a..1808396986e 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -19,7 +19,7 @@ import java.util.Optional;
public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
-
+
private final TensorDataType dataType;
/** Create an empty tensor field value */
@@ -66,7 +66,7 @@ public class TensorFieldValue extends FieldValue {
o.getClass().getName() + "'.");
}
}
-
+
public void assignTensor(Optional<Tensor> tensor) {
if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index f37fa5ea675..29ba244a9f1 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -146,9 +146,9 @@ public class JsonReaderTestCase {
}
{
DocumentType x = new DocumentType("testtensor");
- x.addField(new Field("mappedtensorfield",
+ x.addField(new Field("mappedtensorfield",
new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build())));
- x.addField(new Field("indexedtensorfield",
+ x.addField(new Field("indexedtensorfield",
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
types.registerDocumentType(x);
}
@@ -1280,8 +1280,8 @@ public class JsonReaderTestCase {
return (DocumentPut) reader.next();
}
- private DocumentPut createPutWithMappedTensor(String inputTensor) {
- return createPutWithTensor(inputTensor, "mappedtensorfield");
+ private DocumentPut createPutWithMappedTensor(String inputTensor) {
+ return createPutWithTensor(inputTensor, "mappedtensorfield");
}
private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) {
InputStream rawDoc = new ByteArrayInputStream(
diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
index 7104c1686f8..5c65b11a0c4 100644
--- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
+++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
@@ -24,7 +24,7 @@ public class TensorFieldValueSerializationTestCase {
private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build();
private final static String TENSOR_FIELD = "my_tensor";
private final static String TENSOR_FILES = "src/test/resources/tensor/";
- private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(),
+ private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(),
"id:test:my_type::foo");
private static DocumentType createDocType() {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
index 0f08bf0bf21..9ef1a3f6e32 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
@@ -56,7 +56,7 @@ public abstract class FieldUpdateHelper {
} else if (upd instanceof ArithmeticValueUpdate) {
if (((ArithmeticValueUpdate)upd).getOperator() == ArithmeticValueUpdate.Operator.DIV &&
((ArithmeticValueUpdate)upd).getOperand().doubleValue() == 0) {
- throw new IllegalArgumentException("Division by zero.");
+ throw new IllegalArgumentException("Div by zero.");
}
val.assign(upd.getValue());
return val;
diff --git a/pom.xml b/pom.xml
index eb1f954ce13..b196034380b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -23,6 +23,941 @@
</developer>
</developers>
+ <distributionManagement>
+ <repository>
+ <id>bintray-vespa-repo</id>
+ <url>https://api.bintray.com/maven/yahoo/maven/vespa;publish=1</url>
+ </repository>
+ </distributionManagement>
+
+ <repositories>
+ <!-- Required for Athenz libraries -->
+ <repository>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ <id>bintray-yahoo-maven</id>
+ <name>bintray</name>
+ <url>https://yahoo.bintray.com/maven</url>
+ </repository>
+ </repositories>
+
+ <scm>
+ <connection>scm:git:git@github.com:vespa-engine/vespa.git</connection>
+ <developerConnection>scm:git:git@github.com:vespa-engine/vespa.git</developerConnection>
+ <url>git@github.com:vespa-engine/vespa.git</url>
+ </scm>
+
+ <build>
+ <finalName>${project.artifactId}</finalName>
+ <extensions>
+ <extension>
+ <groupId>org.apache.maven.wagon</groupId>
+ <artifactId>wagon-ssh-external</artifactId>
+ <version>2.7</version>
+ </extension>
+ <extension>
+ <groupId>org.apache.maven.archetype</groupId>
+ <artifactId>archetype-packaging</artifactId>
+ <version>2.0</version>
+ </extension>
+ </extensions>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr3-maven-plugin</artifactId>
+ <version>${antlr.version}</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <version>1.7</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>maven-bundle-plugin</artifactId>
+ <version>2.4.0</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <version>2.4</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>3.6.1</version>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ <showWarnings>true</showWarnings>
+ <optimize>true</optimize>
+ <showDeprecation>false</showDeprecation>
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Xlint:-try</arg>
+ <arg>-Xlint:-processing</arg>
+ <arg>-Xlint:-varargs</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <version>2.10</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-deploy-plugin</artifactId>
+ <version>2.5</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-install-plugin</artifactId>
+ <version>2.5.2</version>
+ <configuration>
+ <updateReleaseInfo>true</updateReleaseInfo>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>3.0.2</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <configuration>
+ <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam>
+ </configuration>
+ <version>2.10.4</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-plugin-plugin</artifactId>
+ <version>3.5</version>
+ <configuration>
+ <!-- see http://jira.codehaus.org/browse/MNG-5346 -->
+ <skipErrorNoDescriptorsFound>true</skipErrorNoDescriptorsFound>
+ </configuration>
+ <executions>
+ <execution>
+ <id>mojo-descriptor</id>
+ <goals>
+ <goal>descriptor</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-resources-plugin</artifactId>
+ <version>2.7</version>
+ <configuration>
+ <escapeString>\</escapeString>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-site-plugin</artifactId>
+ <version>3.3</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.1.2</version>
+ <configuration>
+ <includePom>true</includePom>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>${surefire.version}</version>
+ <configuration>
+ <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile>
+ <systemPropertyVariables>
+ <java.io.tmpdir>${project.build.directory}</java.io.tmpdir>
+ </systemPropertyVariables>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-report-plugin</artifactId>
+ <version>${surefire.version}</version>
+ <configuration>
+ <alwaysGenerateSurefireReport>false</alwaysGenerateSurefireReport>
+ <showSuccess>false</showSuccess>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <version>1.9.1</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>1.6.0</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>javacc-maven-plugin</artifactId>
+ <version>2.6</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>properties-maven-plugin</artifactId>
+ <version>1.0.0</version>
+ </plugin>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>3.2.2</version>
+ <configuration>
+ <args>
+ <arg>-unchecked</arg>
+ <arg>-deprecation</arg>
+ <arg>-feature</arg>
+ <arg>-Xfatal-warnings</arg>
+ </args>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>bundle-plugin</artifactId>
+ <version>${project.version}</version>
+ <configuration>
+ <configGenVersion>${project.version}</configGenVersion>
+ <useCommonAssemblyIds>true</useCommonAssemblyIds>
+ </configuration>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ </build>
+ <profiles>
+ <profile>
+ <id>attach-sources</id>
+ <activation>
+ <property>
+ <name>!skipSources</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>generate-javadoc</id>
+ <activation>
+ <property>
+ <name>!skipJavadoc</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>generate-javadoc</id>
+ <phase>package</phase>
+ <goals>
+ <goal>javadoc</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam>
+ <failOnError>${javadoc.failOnError}</failOnError>
+ <quiet>true</quiet>
+ <show>private</show>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>coverage</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <configuration>
+ <includePluginDependencies>true</includePluginDependencies>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>add-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/main/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ <execution>
+ <id>add-test-source</id>
+ <phase>generate-test-sources</phase>
+ <goals>
+ <goal>add-test-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/test/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>sign-artifacts</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.6</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+ <dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.maven.wagon</groupId>
+ <artifactId>wagon-ssh-external</artifactId>
+ <version>2.7</version>
+ </dependency>
+ <dependency>
+ <groupId>com.github.cverges.expect4j</groupId>
+ <artifactId>expect4j</artifactId>
+ <version>1.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-compress</artifactId>
+ <version>1.11</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-exec</artifactId>
+ <version>1.3</version>
+ </dependency>
+ <dependency>
+ <groupId>io.airlift</groupId>
+ <artifactId>airline</artifactId>
+ <version>0.7</version>
+ </dependency>
+ <dependency>
+ <groupId>aopalliance</groupId>
+ <artifactId>aopalliance</artifactId>
+ <version>1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ <version>5.2</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>annotations</artifactId>
+ <version>1.3.9</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ <version>1.3.9</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <version>18.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava-testlib</artifactId>
+ <version>18.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject</groupId>
+ <artifactId>guice</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject</groupId>
+ <artifactId>guice</artifactId>
+ <version>3.0</version>
+ <classifier>no_aop</classifier>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject.extensions</groupId>
+ <artifactId>guice-assistedinject</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject.extensions</groupId>
+ <artifactId>guice-multibindings</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>3.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.googlecode.jmockit</groupId>
+ <artifactId>jmockit</artifactId>
+ <version>1.2</version>
+ </dependency>
+ <dependency>
+ <groupId>com.goldmansachs</groupId>
+ <artifactId>gs-collections</artifactId>
+ <version>6.1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-json-provider</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-jaxb-annotations</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-base</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-xml-provider</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.dataformat</groupId>
+ <artifactId>jackson-dataformat-xml</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.datatype</groupId>
+ <artifactId>jackson-datatype-jdk8</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.datatype</groupId>
+ <artifactId>jackson-datatype-jsr310</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.infradna.tool</groupId>
+ <artifactId>bridge-method-annotation</artifactId>
+ <version>1.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ <version>1.3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-codec</groupId>
+ <artifactId>commons-codec</artifactId>
+ <version>1.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-collections</groupId>
+ <artifactId>commons-collections</artifactId>
+ <version>3.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-configuration</groupId>
+ <artifactId>commons-configuration</artifactId>
+ <version>1.6</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-daemon</groupId>
+ <artifactId>commons-daemon</artifactId>
+ <version>1.0.3</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <version>2.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-lang</groupId>
+ <artifactId>commons-lang</artifactId>
+ <version>${commons-lang.version}</version>
+ </dependency>
+ <dependency>
+ <!-- This version is exported by jdisc via jcl-over-slf4j. -->
+ <groupId>commons-logging</groupId>
+ <artifactId>commons-logging</artifactId>
+ <version>1.1.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ <version>2.0</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-pool</groupId>
+ <artifactId>commons-pool</artifactId>
+ <version>1.5.6</version>
+ </dependency>
+ <!-- Explicitly included to get Zookeeper version 3.4.10,
+ can be excluded if you want the Zookeeper version
+ used by curator by default
+ -->
+ <dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ <version>3.4.10</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-recipes</artifactId>
+ <version>${curator.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-test</artifactId>
+ <version>${curator.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>javax.servlet</groupId>
+ <artifactId>javax.servlet-api</artifactId>
+ <version>3.1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>4.12</version>
+ </dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr-runtime</artifactId>
+ <version>${antlr.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ <version>${antlr4.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.aries.spifly</groupId>
+ <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId>
+ <version>${aries.spifly.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-lang3</artifactId>
+ <version>3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.framework</artifactId>
+ <version>4.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.log</artifactId>
+ <version>1.0.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.main</artifactId>
+ <version>4.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>fluent-hc</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpcore</artifactId>
+ <version>4.3.3</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpmime</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-artifact</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-core</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-model</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.plugin-tools</groupId>
+ <artifactId>maven-plugin-annotations</artifactId>
+ <version>3.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-plugin-api</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-project</artifactId>
+ <version>2.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>3.0.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.surefire</groupId>
+ <artifactId>surefire-junit4</artifactId>
+ <version>${surefire.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.surefire</groupId>
+ <artifactId>surefire-providers</artifactId>
+ <version>${surefire.version}</version>
+ <type>pom</type>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jettison</groupId>
+ <artifactId>jettison</artifactId>
+ <version>1.3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.cthul</groupId>
+ <artifactId>cthul-matchers</artifactId>
+ <version>1.0</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-continuation</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-servlet</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-servlets</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-http</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-jmx</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>uk.co.datumedge</groupId>
+ <artifactId>hamcrest-json</artifactId>
+ <version>0.2</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hdrhistogram</groupId>
+ <artifactId>HdrHistogram</artifactId>
+ <version>2.1.8</version>
+ </dependency>
+ <dependency>
+ <groupId>org.json</groupId>
+ <artifactId>json</artifactId>
+ <version>20090211</version>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <version>1.9.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <version>1.9.5</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.osgi</groupId>
+ <artifactId>org.osgi.compendium</artifactId>
+ <version>4.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.osgi</groupId>
+ <artifactId>org.osgi.core</artifactId>
+ <version>4.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scala-library</artifactId>
+ <version>${scala.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang.modules</groupId>
+ <artifactId>scala-parser-combinators_${scala.major-version}</artifactId>
+ <version>1.0.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang.modules</groupId>
+ <artifactId>scala-xml_${scala.major-version}</artifactId>
+ <version>1.0.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.major-version}</artifactId>
+ <version>2.2.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>jcl-over-slf4j</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>log4j-over-slf4j</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jdk14</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.springframework</groupId>
+ <artifactId>spring-test</artifactId>
+ <version>4.0.6.RELEASE</version>
+ </dependency>
+ <dependency>
+ <groupId>org.testng</groupId>
+ <artifactId>testng</artifactId>
+ <version>6.10</version>
+ </dependency>
+ <dependency>
+ <groupId>org.twdata.maven</groupId>
+ <artifactId>mojo-executor</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.jcip</groupId>
+ <artifactId>jcip-annotations</artifactId>
+ <version>1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.jpountz.lz4</groupId>
+ <artifactId>lz4</artifactId>
+ <version>1.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.spy</groupId>
+ <artifactId>spymemcached</artifactId>
+ <version>2.10.1</version>
+ </dependency>
+ <dependency>
+ <groupId>xerces</groupId>
+ <artifactId>xercesImpl</artifactId>
+ <version>2.11.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcpkix-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcprov-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
+ <!-- jersey 2 support -->
+ <dependency>
+ <groupId>javax.ws.rs</groupId>
+ <artifactId>javax.ws.rs-api</artifactId>
+ <version>${javax.ws.rs-api.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.containers</groupId>
+ <artifactId>jersey-container-servlet-core</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.containers</groupId>
+ <artifactId>jersey-container-servlet</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.media</groupId>
+ <artifactId>jersey-media-json-jackson</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.media</groupId>
+ <artifactId>jersey-media-multipart</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.ext</groupId>
+ <artifactId>jersey-proxy-client</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.core</groupId>
+ <artifactId>jersey-client</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.ibm.icu</groupId>
+ <artifactId>icu4j</artifactId>
+ <version>57.1</version>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.athenz</groupId>
+ <artifactId>athenz-zms-java-client</artifactId>
+ <version>${athenz.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.athenz</groupId>
+ <artifactId>athenz-zts-java-client</artifactId>
+ <version>${athenz.version}</version>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+
+ <properties>
+ <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> <!-- must be kept in sync with version used by current jersey2.version -->
+ <antlr.version>3.5.2</antlr.version>
+ <antlr4.version>4.5</antlr4.version>
+ <aries.spifly.version>1.0.8</aries.spifly.version>
+ <aries.util.version>1.0.0</aries.util.version>
+ <asm-debug-all.version>5.0.3</asm-debug-all.version>
+ <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories -->
+ <athenz.version>1.7.28</athenz.version>
+ <bouncycastle.version>1.58</bouncycastle.version>
+ <commons-lang.version>2.6</commons-lang.version>
+ <!-- WARNING: If you change curator version, you also need to update
+ zkfacade/src/main/java/org/apache/curator/**/package-info.java
+ using something like
+ find zkfacade/src/main/java/org/apache/curator -name package-info.java | \
+ xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g'
+ -->
+ <curator.version>2.9.1</curator.version>
+ <jackson2.version>2.8.3</jackson2.version>
+ <jersey2.version>2.23.2</jersey2.version>
+ <jetty.version>9.4.6.v20170531</jetty.version>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
+ <test.hide>true</test.hide>
+ <doclint>all</doclint>
+ <scala.major-version>2.11</scala.major-version>
+ <scala.version>${scala.major-version}.4</scala.version>
+ <surefire.version>2.19.1</surefire.version> <!-- NOTE bjorncs 15.06.2017: Version 2.20 has OoM issues -->
+ </properties>
+
<modules>
<module>application</module>
<module>application-deploy-plugin</module>
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 5f6717d9516..09ccf9928b7 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -36,6 +36,21 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>3.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>1.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ <version>1.4.0</version>
+ </dependency>
+ <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>test</scope>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 785ed78492e..0eeb0a9e630 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
*
- * @param name The name of the variable whose value to return.
- * @return The value of the named variable.
+ * @param name the name of the variable whose value to return.
+ * @return the value of the named variable.
*/
public abstract Value get(String name);
+ /** Returns a variable as a tensor */
+ @Override
+ public Tensor getTensor(String name) { return get(name).asTensor(); }
+
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
* string. This may be used to implement more advanced variables whose
* values are calculated at runtime from arguments. Supporting this in a
- * context is optional.
- *
+ * context is optional.
+ *
* <p>This default implementation generates a name on the form
* <code>name(argument1, argument2, ...argumentN).output</code>.
* If there are no arguments the parenthesis are omitted.
* If there is no output, the dot is omitted.</p>
*
- * @param name The name of this variable.
- * @param arguments The parsed arguments as given in the textual expression.
- * @param output The name of the value to output (to enable one named
+ * @param name the name of this variable.
+ * @param arguments the parsed arguments as given in the textual expression.
+ * @param output the name of the value to output (to enable one named
* calculation to output several), or null to output the
* "main" (or only) value.
*/
@@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext {
* context subclasses. This default implementation throws
* UnsupportedOperationException.</p>
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public Value get(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
}
/**
- * <p>Lookup by index rather than name directly to a double. This is supported by some optimized
+ * Lookup by index rather than name directly to a double. This is supported by some optimized
* context subclasses. This default implementation throws
- * UnsupportedOperationException.</p>
+ * UnsupportedOperationException.
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public double getDouble(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
@@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext {
}
/**
- * <p>Sets a value to this, or throws an UnsupportedOperationException if
- * this is not supported. This default implementation does the latter.</p> *
+ * Sets a value to this, or throws an UnsupportedOperationException if
+ * this is not supported. This default implementation does the latter.
*
- * @param name The name of the variable to set.
+ * @param name the name of the variable to set.
* @param value the value to set. Ownership of this value is transferred to this - if it is mutable
* (not frozen) it may be modified during execution
- * @since 5.1.5
*/
public void put(String name, Value value) {
throw new UnsupportedOperationException(this + " does not support variable assignment");
}
/**
- * <p>Returns all the names available in this, or throws an
+ * Returns all the names available in this, or throws an
* UnsupportedOperationException if this operation is not supported. This
- * default implementation does the latter.</p>
+ * default implementation does the latter.
*
- * @return The set of all variable names.
+ * @return the set of all variable names.
*/
public Set<String> names() {
throw new UnsupportedOperationException(this + " does not support return a list of its names");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index ea750295423..2ef4a2ede2f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
@@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value {
public boolean hasDouble() { return true; }
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index ac8aba6a617..dad69b31181 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A string value.
*
* @author bratseth
- * @since 5.1.21
*/
public class StringValue extends Value {
@@ -35,6 +37,11 @@ public class StringValue extends Value {
}
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 49c3ccb7b01..26c30fe5ed2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -2,14 +2,10 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.Optional;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
/**
* A Value containing a tensor.
@@ -23,7 +19,7 @@ public class TensorValue extends Value {
/** The tensor value of this */
private final Tensor value;
-
+
public TensorValue(Tensor value) {
this.value = value;
}
@@ -131,7 +127,7 @@ public class TensorValue extends Value {
public Value compare(TruthOperator operator, Value argument) {
return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
}
-
+
private Tensor compareTensor(TruthOperator operator, Tensor argument) {
switch (operator) {
case LARGER: return value.larger(argument);
@@ -152,7 +148,7 @@ public class TensorValue extends Value {
else
return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
}
-
+
private Tensor functionOnTensor(Function function, Tensor argument) {
switch (function) {
case min: return value.min(argument);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index b2ccbe572d0..40d70e0022c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* The result of a ranking expression evaluation.
@@ -25,6 +27,14 @@ public abstract class Value {
return new DoubleValue(asDouble());
}
+ /** Returns this as a tensor value */
+ public abstract Tensor asTensor();
+
+ /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ protected Tensor doubleAsTensor(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
+ }
+
/** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */
public abstract boolean hasDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
new file mode 100644
index 00000000000..b4a9b363ade
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -0,0 +1,51 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * The result of importing a TensorFlow model into Vespa:
+ * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
+ * - A list of named constant tensors
+ * - A list of expected input tensors, with their tensor type
+ * - A list of warning messages
+ *
+ * @author bratseth
+ */
+// This object can be built incrementally within this package, but is immutable when observed from outside the package
+// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
+public class ImportResult {
+
+ private final List<RankingExpression> expressions = new ArrayList<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final List<String> warnings = new ArrayList<>();
+
+ void add(RankingExpression expression) { expressions.add(expression); }
+ void set(String name, Tensor constant) { constants.put(name, constant); }
+ void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void warn(String warning) { warnings.add(warning); }
+
+ /** Returns an immutable list of the expressions of this */
+ public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+
+ /** Returns an immutable map of the constants of this */
+ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+
+ /** Returns an immutable map of the arguments of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+ /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
+ public List<String> warnings() {
+ return warnings.stream().sorted().collect(Collectors.toList());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
new file mode 100644
index 00000000000..e7f7b5ef2f4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -0,0 +1,160 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
+ *
+ * @author bratseth
+ */
+class OperationMapper {
+
+ /*
+ A note on conversion from implicitly numbered to explicitly named dimensions:
+ Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+ 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+ comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+ around dimension renaming operations which mirrors those built into the TF operation definitions.
+
+ To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+ dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+ and the result is then renamed again (if necessary) to recover this convention across a full nested
+ computation.
+
+ This requires us to track tensor types throughout the conversion.
+ */
+
+ private TensorConverter tensorConverter = new TensorConverter();
+
+ TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+ ensureArguments(2, arguments, "join");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < b.type().rank())
+ throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " +
+ "but this is not supported when the second argument has a higher rank");
+
+ TensorFunction bFunction = b.function();
+
+ if (a.type().rank() > b.type().rank()) {
+ // Well now we have entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensor do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+ List<String> renameFrom = new ArrayList<>();
+ List<String> renameTo = new ArrayList<>();
+ int sizeDifference = a.type().rank() - b.type().rank();
+ for (int i = 0; i < b.type().rank(); i++) {
+ renameFrom.add(b.type().dimensions().get(i).name());
+ renameTo.add("d" + (sizeDifference + i));
+ }
+ bFunction = new Rename(bFunction, renameFrom, renameTo);
+ }
+
+ Join function = new Join(a.function(), bFunction, doubleFunction);
+ return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank
+ }
+
+ TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+ ensureArguments(1, arguments, "apply");
+ TypedTensorFunction a = arguments.get(0);
+
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ String name = tfNode.getName();
+ TensorType type = result.arguments().get(name);
+ if (type == null)
+ throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name +
+ "', but there is no such input");
+ // Included literally in the expression and so must be produced by a separate macro in the rank profile
+ return new TypedTensorFunction(type, new VariableTensor(name));
+ }
+
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ if ( ! tfNode.getName().endsWith("/read"))
+ throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
+ "nodes are only supported when reading variables");
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+
+ String name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+ Session.Runner fetched = model.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if ( importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
+ importedTensors.size());
+ Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
+ result.set(name, constant);
+ return new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
+ }
+
+ TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
+ ensureArguments(2, arguments, "matmul");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < 2 || b.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (a.type().rank() != b.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ String afterLastDim = "d" + (a.type().rank() + 1);
+ // Let the first dimension of the second tensor be the same as the second dimension of the first
+ // and the second dimension of the second argument be not present in the first argument, while leaving the
+ // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
+
+ // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
+
+ Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
+ ImmutableList.of("d1", afterLastDim));
+ Matmul matmul = new Matmul(a.function(), renamedB, "d1");
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
+ new Rename(matmul, afterLastDim, "d1"));
+ }
+
+ TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
+ ensureArguments(1, arguments, "softmax");
+ TypedTensorFunction a = arguments.get(0);
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
+ Softmax softmax = new Softmax(a.function(), dimension);
+ return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ }
+
+ private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
+ if ( arguments.size() != count)
+ throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
+ ", but got " + arguments.size());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
new file mode 100644
index 00000000000..df43225c333
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -0,0 +1,94 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+
+/**
+ * @author bratseth
+ */
+public class TensorConverter {
+
+ public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ TensorType type = toVespaTensorType(tfTensor.shape());
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ private TensorType toVespaTensorType(long[] shape) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed("d" + (dimensionIndex++), (int) dimensionSize);
+ }
+ return b.build();
+ }
+
+ private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ // TODO: The rest
+ default:
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+
+ private final int size;
+
+ protected Values(int size) {
+ this.size = size;
+ }
+
+ abstract double get(int i);
+
+ int size() { return size; }
+
+ }
+
+ private static class DoubleValues extends Values {
+
+ private final DoubleBuffer values;
+
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+ private static class FloatValues extends Values {
+
+ private final FloatBuffer values;
+
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
new file mode 100644
index 00000000000..33523244129
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -0,0 +1,147 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ */
+public class TensorFlowImporter {
+
+ private final OperationMapper operationMapper = new OperationMapper();
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a pbtxt file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportResult importModel(String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
+ SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
+ ImportResult result = new ImportResult();
+ importInputs(signature.getInputsMap(), result);
+ result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
+ return result;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ ImportResult result = new ImportResult();
+ for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
+ importInputs(signatureEntry.getValue().getInputsMap(), result);
+ for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ try {
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
+ result.add(new RankingExpression(output.getKey(), node));
+ }
+ catch (IllegalArgumentException e) {
+ result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
+ signatureEntry.getValue().getMethodName() +
+ "': " + Exceptions.toMessageString(e));
+ }
+ }
+ }
+ return result;
+ }
+
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
+ inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
+ importTensorType(value.getTensorShape())));
+ }
+
+ private TensorType importTensorType(TensorShapeProto tensorShape) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
+ int dimensionSize = (int)dimension.getSize();
+ if (dimensionSize >= 0)
+ b.indexed("d" + b.rank(), dimensionSize);
+ else
+ b.indexed("d" + b.rank()); // unbound size
+ }
+ return b.build();
+ }
+
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return importNode(nameOf(output.getName()), graph, model, result);
+ }
+
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
+ return new TensorFunctionNode(function); // wrap top level (only) as an expression
+ }
+
+ /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tensorFunctionOf(tfNode, graph, model, result);
+ }
+
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
+ // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
+ switch (tfNode.getOp().toLowerCase()) {
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, result);
+ case "identity" : return operationMapper.identity(tfNode, model, result);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
+ default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
+ }
+ }
+
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tfNode.getInputList().stream()
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
+ .collect(Collectors.toList());
+ }
+
+ private NodeDef getNode(String name, GraphDef graph) {
+ return graph.getNodeList().stream()
+ .filter(node -> node.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private String nameOf(String name) {
+ return name.split(":")[0];
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
new file mode 100644
index 00000000000..5712da77700
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+/**
+ * A tensor function returning a specific tensor type
+ *
+ * @author bratseth
+ */
+final class TypedTensorFunction {
+
+ private final TensorType type;
+ private final TensorFunction function;
+
+ public TypedTensorFunction(TensorType type, TensorFunction function) {
+ this.type = type;
+ this.function = function;
+ }
+
+ public TensorType type() { return type; }
+ public TensorFunction function() { return function; }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 71699b379b2..d366c9bfbe5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -14,23 +14,23 @@ import java.util.function.*;
/**
* A tensor generating function, whose arguments are determined by a tensor type
- *
+ *
* @author bratseth
*/
public class GeneratorLambdaFunctionNode extends CompositeNode {
private final TensorType type;
private final ExpressionNode generator;
-
+
public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
- throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
+ throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
"dimensions, but tried to generate " + type);
// TODO: Verify that the function only accesses the given arguments
this.type = type;
this.generator = generator;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(generator);
@@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
public Value evaluate(Context context) {
return generator.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as an operator which converts a list of integers into a double
*/
public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
@@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
context.put(type.dimensions().get(i).name(), arguments.get(i));
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return GeneratorLambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 1f8db6e036c..ba765d07094 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -17,7 +17,7 @@ import java.util.Map;
* @author bratseth
*/
public class SerializationContext {
-
+
/** Expression functions indexed by name */
private final ImmutableMap<String, ExpressionFunction> functions;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ce21e132980..8af3448ca6f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -21,22 +21,32 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorFunctionNode extends CompositeNode {
private final TensorFunction function;
-
+
public TensorFunctionNode(TensorFunction function) {
this.function = function;
}
+ /** Returns the tensor function wrapped by this */
+ public TensorFunction function() { return function; }
+
@Override
public List<ExpressionNode> children() {
return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .map(this::toExpressionNode)
.collect(Collectors.toList());
}
+ private ExpressionNode toExpressionNode(TensorFunction f) {
+ if (f instanceof TensorFunctionExpressionNode)
+ return ((TensorFunctionExpressionNode)f).expression;
+ else
+ return new TensorFunctionNode(f);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
@@ -50,7 +60,7 @@ public class TensorFunctionNode extends CompositeNode {
// Serialize as primitive
return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
}
-
+
@Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
@@ -59,8 +69,8 @@ public class TensorFunctionNode extends CompositeNode {
public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
-
- /**
+
+ /**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
@@ -68,13 +78,13 @@ public class TensorFunctionNode extends CompositeNode {
/** An expression which produces a tensor */
private final ExpressionNode expression;
-
+
public TensorFunctionExpressionNode(ExpressionNode expression) {
this.expression = expression;
}
-
+
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> functionArguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -108,7 +118,7 @@ public class TensorFunctionNode extends CompositeNode {
public String toString() {
return toString(ExpressionNodeToStringContext.empty);
}
-
+
@Override
public String toString(ToStringContext c) {
if (c instanceof ExpressionNodeToStringContext) {
@@ -121,14 +131,14 @@ public class TensorFunctionNode extends CompositeNode {
}
}
-
+
/** Allows passing serialization context arguments through TensorFunctions */
private static class ExpressionNodeToStringContext implements ToStringContext {
-
+
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
-
+
public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
new file mode 100644
index 00000000000..a1861a1c981
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A very simple MNIST classifier.
+
+See extensive documentation at
+https://www.tensorflow.org/get_started/mnist/beginners
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+import tensorflow as tf
+
+FLAGS = None
+
+
+def main(_):
+ # Import data
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+
+ # Create the model
+ x = tf.placeholder(tf.float32, [None, 784])
+ W = tf.Variable(tf.zeros([784, 10]))
+ b = tf.Variable(tf.zeros([10]))
+ y = tf.matmul(x, W) + b
+
+ # Define loss and optimizer
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ # The raw formulation of cross-entropy,
+ #
+ # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
+ # reduction_indices=[1]))
+ #
+ # can be numerically unstable.
+ #
+ # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+ # outputs of 'y', and then average across the batch.
+ cross_entropy = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+ # Train
+ for _ in range(1000):
+ batch_xs, batch_ys = mnist.train.next_batch(100)
+ sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
+
+ # Test trained model
+ correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print(sess.run(accuracy, feed_dict={x: mnist.test.images,
+ y_: mnist.test.labels}))
+
+ # Save the model
+ export_path = "saved"
+ print('Exporting trained model to ', export_path)
+ builder = tf.saved_model.builder.SavedModelBuilder(export_path)
+ signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y})
+ builder.add_meta_graph_and_variables(sess,
+ [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={'serving_default':signature})
+ builder.save(as_text=True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
+ help='Directory for storing input data')
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
new file mode 100644
index 00000000000..8100dfd594d
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
@@ -0,0 +1,5039 @@
+saved_model_schema_version: 1
+meta_graphs {
+ meta_info_def {
+ stripped_op_list {
+ op {
+ name: "Add"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+ }
+ op {
+ name: "ApplyGradientDescent"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "delta"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ }
+ op {
+ name: "ArgMax"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dimension"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "output_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "output_type"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Assign"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "validate_shape"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ allows_uninitialized_input: true
+ }
+ op {
+ name: "BroadcastGradientArgs"
+ input_arg {
+ name: "s0"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "s1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r0"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r1"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Cast"
+ input_arg {
+ name: "x"
+ type_attr: "SrcT"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "DstT"
+ }
+ attr {
+ name: "SrcT"
+ type: "type"
+ }
+ attr {
+ name: "DstT"
+ type: "type"
+ }
+ }
+ op {
+ name: "ConcatV2"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Const"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ }
+ op {
+ name: "Equal"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_QUINT8
+ type: DT_QINT8
+ type: DT_QINT32
+ type: DT_STRING
+ type: DT_BOOL
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "ExpandDims"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dim"
+ type_attr: "Tdim"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tdim"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Fill"
+ input_arg {
+ name: "dims"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "FloorDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Identity"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "MatMul"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ type_attr: "T"
+ }
+ attr {
+ name: "transpose_a"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "transpose_b"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "Mean"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "MergeV2Checkpoints"
+ input_arg {
+ name: "checkpoint_prefixes"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "destination_prefix"
+ type: DT_STRING
+ }
+ attr {
+ name: "delete_old_dirs"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Mul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "NoOp"
+ }
+ op {
+ name: "Pack"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "axis"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ }
+ op {
+ name: "Placeholder"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ op {
+ name: "Prod"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RealDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Reshape"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "shape"
+ type_attr: "Tshape"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tshape"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RestoreV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "SaveV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Shape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "ShardedFilename"
+ input_arg {
+ name: "basename"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shard"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "num_shards"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "filename"
+ type: DT_STRING
+ }
+ }
+ op {
+ name: "Slice"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "begin"
+ type_attr: "Index"
+ }
+ input_arg {
+ name: "size"
+ type_attr: "Index"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Index"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "SoftmaxCrossEntropyWithLogits"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "labels"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "loss"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprop"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ }
+ op {
+ name: "StringJoin"
+ input_arg {
+ name: "inputs"
+ type: DT_STRING
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "separator"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ }
+ op {
+ name: "Sub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sum"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Tile"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "multiples"
+ type_attr: "Tmultiples"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tmultiples"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "VariableV2"
+ output_arg {
+ name: "ref"
+ type_attr: "dtype"
+ is_ref: true
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "ZerosLike"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ }
+ tags: "serve"
+ tensorflow_version: "1.4.1"
+ tensorflow_git_version: "v1.4.0-19-ga52c8d9b01"
+ }
+ graph_def {
+ node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "zeros"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable/read"
+ op: "Identity"
+ input: "Variable"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable_1"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable_1/Assign"
+ op: "Assign"
+ input: "Variable_1"
+ input: "zeros_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable_1/read"
+ op: "Identity"
+ input: "Variable_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "add"
+ op: "Add"
+ input: "MatMul"
+ input: "Variable_1/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Placeholder_1"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Rank_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_1"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub"
+ op: "Sub"
+ input: "Rank_1"
+ input: "Sub/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice/begin"
+ op: "Pack"
+ input: "Sub"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice"
+ op: "Slice"
+ input: "Shape_1"
+ input: "Slice/begin"
+ input: "Slice/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "concat/values_0"
+ input: "Slice"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape"
+ op: "Reshape"
+ input: "add"
+ input: "concat"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_2"
+ op: "Shape"
+ input: "Placeholder_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub_1/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_1"
+ op: "Sub"
+ input: "Rank_2"
+ input: "Sub_1/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1/begin"
+ op: "Pack"
+ input: "Sub_1"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_1/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1"
+ op: "Slice"
+ input: "Shape_2"
+ input: "Slice_1/begin"
+ input: "Slice_1/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "concat_1/values_0"
+ input: "Slice_1"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_1"
+ op: "Reshape"
+ input: "Placeholder_1"
+ input: "concat_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "SoftmaxCrossEntropyWithLogits"
+ op: "SoftmaxCrossEntropyWithLogits"
+ input: "Reshape"
+ input: "Reshape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2"
+ op: "Sub"
+ input: "Rank"
+ input: "Sub_2/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/begin"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/size"
+ op: "Pack"
+ input: "Sub_2"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_2"
+ op: "Slice"
+ input: "Shape"
+ input: "Slice_2/begin"
+ input: "Slice_2/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_2"
+ op: "Reshape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ input: "Slice_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean"
+ op: "Mean"
+ input: "Reshape_2"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Fill"
+ op: "Fill"
+ input: "gradients/Shape"
+ input: "gradients/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape/shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Fill"
+ input: "gradients/Mean_grad/Reshape/shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Tile"
+ op: "Tile"
+ input: "gradients/Mean_grad/Reshape"
+ input: "gradients/Mean_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tmultiples"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_1"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_1"
+ input: "gradients/Mean_grad/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const_1"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod_1"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_2"
+ input: "gradients/Mean_grad/Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum/y"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum"
+ op: "Maximum"
+ input: "gradients/Mean_grad/Prod_1"
+ input: "gradients/Mean_grad/Maximum/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/floordiv"
+ op: "FloorDiv"
+ input: "gradients/Mean_grad/Prod"
+ input: "gradients/Mean_grad/Maximum"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Cast"
+ op: "Cast"
+ input: "gradients/Mean_grad/floordiv"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/truediv"
+ op: "RealDiv"
+ input: "gradients/Mean_grad/Tile"
+ input: "gradients/Mean_grad/Cast"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Shape"
+ op: "Shape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Mean_grad/truediv"
+ input: "gradients/Reshape_2_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/zeros_like"
+ op: "ZerosLike"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ op: "ExpandDims"
+ input: "gradients/Reshape_2_grad/Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ op: "Mul"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ input: "gradients/Reshape_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape"
+ op: "Shape"
+ input: "MatMul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 10
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/BroadcastGradientArgs"
+ op: "BroadcastGradientArgs"
+ input: "gradients/add_grad/Shape"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum"
+ input: "gradients/add_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum_1"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape_1"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum_1"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/Reshape_1"
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape_1"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul"
+ op: "MatMul"
+ input: "gradients/add_grad/tuple/control_dependency"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul_1"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "gradients/add_grad/tuple/control_dependency"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/MatMul_1"
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul_1"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/learning_rate"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.5
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/MatMul_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable_1/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable_1"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/add_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent"
+ op: "NoOp"
+ input: "^GradientDescent/update_Variable/ApplyGradientDescent"
+ input: "^GradientDescent/update_Variable_1/ApplyGradientDescent"
+ }
+ node {
+ name: "init"
+ op: "NoOp"
+ input: "^Variable/Assign"
+ input: "^Variable_1/Assign"
+ }
+ node {
+ name: "ArgMax/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax"
+ op: "ArgMax"
+ input: "add"
+ input: "ArgMax/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1"
+ op: "ArgMax"
+ input: "Placeholder_1"
+ input: "ArgMax_1/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "Equal"
+ op: "Equal"
+ input: "ArgMax"
+ input: "ArgMax_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Cast_1"
+ op: "Cast"
+ input: "Equal"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean_1"
+ op: "Mean"
+ input: "Cast_1"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "save/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "model"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin/inputs_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin"
+ op: "StringJoin"
+ input: "save/Const"
+ input: "save/StringJoin/inputs_1"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "separator"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "save/num_shards"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename/shard"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename"
+ op: "ShardedFilename"
+ input: "save/StringJoin"
+ input: "save/ShardedFilename/shard"
+ input: "save/num_shards"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: "Variable"
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: ""
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2"
+ op: "SaveV2"
+ input: "save/ShardedFilename"
+ input: "save/SaveV2/tensor_names"
+ input: "save/SaveV2/shape_and_slices"
+ input: "Variable"
+ input: "Variable_1"
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/control_dependency"
+ op: "Identity"
+ input: "save/ShardedFilename"
+ input: "^save/SaveV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@save/ShardedFilename"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ op: "Pack"
+ input: "save/ShardedFilename"
+ input: "^save/control_dependency"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints"
+ op: "MergeV2Checkpoints"
+ input: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ input: "save/Const"
+ attr {
+ key: "delete_old_dirs"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/Identity"
+ op: "Identity"
+ input: "save/Const"
+ input: "^save/control_dependency"
+ input: "^save/MergeV2Checkpoints"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2/tensor_names"
+ input: "save/RestoreV2/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "save/RestoreV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2_1/tensor_names"
+ input: "save/RestoreV2_1/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign_1"
+ op: "Assign"
+ input: "Variable_1"
+ input: "save/RestoreV2_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/restore_shard"
+ op: "NoOp"
+ input: "^save/Assign"
+ input: "^save/Assign_1"
+ }
+ node {
+ name: "save/restore_all"
+ op: "NoOp"
+ input: "^save/restore_shard"
+ }
+ versions {
+ producer: 24
+ }
+ }
+ saver_def {
+ filename_tensor_name: "save/Const:0"
+ save_tensor_name: "save/Identity:0"
+ restore_op_name: "save/restore_all"
+ max_to_keep: 5
+ sharded: true
+ keep_checkpoint_every_n_hours: 10000.0
+ version: V2
+ }
+ collection_def {
+ key: "train_op"
+ value {
+ node_list {
+ value: "GradientDescent"
+ }
+ }
+ }
+ collection_def {
+ key: "trainable_variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ collection_def {
+ key: "variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ signature_def {
+ key: "serving_default"
+ value {
+ inputs {
+ key: "x"
+ value {
+ name: "Placeholder:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ outputs {
+ key: "y"
+ value {
+ name: "add:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ method_name: "tensorflow/serving/predict"
+ }
+ }
+}
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
new file mode 100644
index 00000000000..8474aa0a04c
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
new file mode 100644
index 00000000000..cfcdac20409
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 82e5d0cfe5b..3aa2d144f1f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.*;
-import com.yahoo.tensor.Tensor;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import org.junit.Test;
+
import static org.junit.Assert.assertEquals;
/**
@@ -83,7 +88,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0, "sin(0)");
tester.assertEvaluates(1, "cos(0)");
tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))");
-
+
// Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero)
tester.assertEvaluates(0, "random(1)");
tester.assertEvaluates(0, "random(foo)");
@@ -152,7 +157,7 @@ public class EvaluationTestCase {
"tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
-
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ee2b1c147e3..ba0db4de5e1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -34,7 +34,7 @@ public class EvaluationTester {
}
// TODO: Test both bound and unbound indexed
- public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
+ public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
String ... tensorArgumentStrings) {
MapContext context = defaultContext.thawedCopy();
int argumentIndex = 0;
@@ -46,7 +46,7 @@ public class EvaluationTester {
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
context.put("tensor" + (argumentIndex++), new TensorValue(argument));
}
- return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
+ return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
mappedTensors ? "Mapped tensors" : "Indexed tensors");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
new file mode 100644
index 00000000000..863479f3531
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -0,0 +1,114 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Mnist_SoftmaxTestCase {
+
+ @Test
+ public void testImporting() {
+ String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ ImportResult result = new TensorFlowImporter().importModel(modelDir);
+
+ // Check logged messages
+ result.warnings().forEach(System.err::println);
+ assertEquals(0, result.warnings().size());
+
+ // Check arguments
+ assertEquals(1, result.arguments().size());
+ TensorType argument0 = result.arguments().get("Placeholder");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // Check constants
+ assertEquals(2, result.constants().size());
+
+ Tensor constant0 = result.constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = result.constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check resulting Vespa expression
+ assertEquals(1, result.expressions().size());
+ assertEquals("y", result.expressions().get(0).getName());
+ assertEquals("" +
+ "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
+ "rename(constant(Variable_1), d0, d1), " +
+ "f(a,b)(a + b))",
+ toNonPrimitiveString(result.expressions().get(0)));
+
+ // Test execution
+ String signatureName = "serving_default";
+
+ assertEqualResult(modelDir, signatureName, "Variable/read");
+ assertEqualResult(modelDir, signatureName, "Variable_1/read");
+ // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
+ assertEqualResult(modelDir, signatureName, "MatMul");
+ assertEqualResult(modelDir, signatureName, "add");
+ }
+
+ private void assertEqualResult(String modelDir, String signatureName, String operationName) {
+ ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
+
+ Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+ Context context = contextFrom(result);
+ Tensor placeholder = placeholderArgument();
+ context.put("Placeholder", new TensorValue(placeholder));
+ Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ }
+
+ private Tensor tensorFlowExecute(String modelDir, String operationName) {
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+ runner.feed("Placeholder", placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(ImportResult result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private String toNonPrimitiveString(RankingExpression expression) {
+ // toString on the wrapping expression will map to primitives, which is harder to read
+ return ((TensorFunctionNode)expression.getRoot()).function().toString();
+ }
+
+ private Tensor placeholderArgument() {
+ int size = 784;
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
+ for (int i = 0; i < size; i++)
+ b.cell(0, 0, i);
+ return b.build();
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index dde9d4bf21e..1960c1fe876 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -59,7 +59,7 @@ public class TensorConformanceTest {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
-
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
return true;
@@ -67,7 +67,7 @@ public class TensorConformanceTest {
if (!node.has("expression")) {
return true; // ignore
}
-
+
String expression = node.get("expression").asText();
MapContext context = getInput(node.get("inputs"));
Tensor expect = getTensor(node.get("result").get("expect").asText());
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
index 7874dcb24ab..16a541f939c 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
@@ -5,6 +5,7 @@ import com.google.common.annotations.Beta;
import com.yahoo.vespa.http.client.Result;
import com.yahoo.vespa.http.client.config.Endpoint;
import com.yahoo.vespa.http.client.core.Document;
+import com.yahoo.vespa.http.client.core.Exceptions;
import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory;
import com.yahoo.vespa.http.client.core.EndpointResult;
import com.yahoo.vespa.http.client.core.ServerResponseException;
@@ -318,29 +319,28 @@ class IOThread implements Runnable, AutoCloseable {
successfullHandshakes.getAndIncrement();
} catch (ServerResponseException ser) {
executeProblemsCounter.incrementAndGet();
- log.log(Level.INFO, "Handshake did not work out " + endpoint, ser.getMessage());
+ log.log(Level.INFO, "Handshake did not work out " + endpoint, Exceptions.toMessageString(ser));
drainFirstDocumentsInQueueIfOld();
return ThreadState.CONNECTED;
} catch (Throwable throwable) { // This cover IOException as well
executeProblemsCounter.incrementAndGet();
- log.log(Level.INFO, "Problem with Handshake " + endpoint, throwable.getMessage());
+ log.log(Level.INFO, "Problem with Handshake " + endpoint, Exceptions.toMessageString(throwable));
drainFirstDocumentsInQueueIfOld();
client.close();
return ThreadState.DISCONNECTED;
}
return ThreadState.SESSION_SYNCED;
case SESSION_SYNCED:
- final int maxWaitTimeMilliSecs = 100;
try {
- ProcessResponse processResponse = pullAndProcessData(maxWaitTimeMilliSecs);
+ ProcessResponse processResponse = pullAndProcessData(100);
gatewayThrottler.handleCall(processResponse.transitiveErrorCount);
}
catch (ServerResponseException ser) {
- log.info("Problems while handing data over to gateway " + endpoint + " " + ser.getMessage());
+ log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(ser));
return ThreadState.CONNECTED;
}
catch (Throwable e) { // Covers IOException as well
- log.info("Problems while handing data over to gateway " + endpoint + " " + e.getMessage());
+ log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(e));
client.close();
return ThreadState.DISCONNECTED;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 00e106dd035..f6237a1977a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -7,7 +7,7 @@ import java.util.Arrays;
/**
* The sizes of a set of dimensions.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,7 +48,7 @@ public final class DimensionSizes {
@Override
public int hashCode() { return Arrays.hashCode(sizes); }
- /**
+ /**
* Builder of a set of dimension sizes.
* Dimensions whose size is not set before building will get size 0.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index c207dabca3a..6b0d769de9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor {
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;
-
+
/** The sizes of the dimensions of this in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
-
+
private final double[] values;
-
+
private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
this.type = type;
this.dimensionSizes = dimensionSizes;
@@ -43,8 +43,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Returns an iterator over the cells of this.
- * Cells are returned in order of increasing indexes in each dimension, increasing
+ * Returns an iterator over the cells of this.
+ * Cells are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor {
/**
* Returns an iterator over the values of this.
- * Values are returned in order of increasing indexes in each dimension, increasing
+ * Values are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor {
* Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions
* given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as
* other iterator.
- *
+ *
* @param dimensions the names of the dimensions of the superspace
* @param sizes the size of each dimension in the space we are returning values for, containing
* one value per dimension of this tensor (in order). Each size may be the same or smaller
@@ -96,9 +96,9 @@ public class IndexedTensor implements Tensor {
return subspaceIterator(dimensions, dimensionSizes);
}
- /**
+ /**
* Returns the value at the given indexes
- *
+ *
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
@@ -119,7 +119,7 @@ public class IndexedTensor implements Tensor {
}
private double get(int valueIndex) { return values[valueIndex]; }
-
+
private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
@@ -165,7 +165,7 @@ public class IndexedTensor implements Tensor {
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
return Collections.singletonMap(TensorAddress.of(), values[0]);
-
+
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
for (int i = 0; i < values.length; i++) {
@@ -174,13 +174,13 @@ public class IndexedTensor implements Tensor {
}
return builder.build();
}
-
+
@Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
public String toString() { return Tensor.toStandardString(this); }
-
+
@Override
public boolean equals(Object other) {
if ( ! ( other instanceof Tensor)) return false;
@@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor {
}
public abstract static class Builder implements Tensor.Builder {
-
+
final TensorType type;
-
+
private Builder(TensorType type) {
this.type = type;
}
@@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor {
return new UnboundBuilder(type);
}
- /**
+ /**
* Create a builder with dimension size information for this instance. Must be one size entry per dimension,
* and, agree with the type size information when specified in the type.
* If sizes are completely specified in the type this size information is redundant.
@@ -210,16 +210,16 @@ public class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
+ throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
"for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Integer> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
- throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
+ throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
sizes.size(i) +
" but cannot be larger than " + size.get() + " in " + type);
}
-
+
return new BoundBuilder(type, sizes);
}
@@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor {
public abstract IndexedTensor build();
}
-
+
/** A bound builder can create the double array directly */
public static class BoundBuilder extends Builder {
@@ -257,13 +257,13 @@ public class IndexedTensor implements Tensor {
this.sizes = sizes;
values = new double[sizes.totalSize()];
}
-
+
@Override
public BoundBuilder cell(double value, int ... indexes) {
values[toValueIndex(indexes, sizes)] = value;
return this;
}
-
+
@Override
public CellBuilder cell() {
return new CellBuilder(type, this);
@@ -294,8 +294,8 @@ public class IndexedTensor implements Tensor {
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
+ /**
+ * Set a cell value by the index in the internal layout of this cell.
* This requires knowledge of the internal layout of cells in this implementation, and should therefore
* probably not be used (but when it can be used it is fast).
*/
@@ -330,7 +330,7 @@ public class IndexedTensor implements Tensor {
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
-
+
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
@@ -347,16 +347,16 @@ public class IndexedTensor implements Tensor {
if (currentDimensionIndex == dimensionSizes.size())
dimensionSizes.add(currentDimension.size());
else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size())
- throw new IllegalArgumentException("Missing values in dimension " +
+ throw new IllegalArgumentException("Missing values in dimension " +
type.dimensions().get(currentDimensionIndex) + " in " + type);
-
+
for (Object value : currentDimension)
if (value instanceof List)
findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List<Object>)value);
}
@SuppressWarnings("unchecked")
- private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
+ private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
DimensionSizes sizes, double[] values) {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (int i = 0; i < currentDimension.size(); i++)
@@ -369,7 +369,7 @@ public class IndexedTensor implements Tensor {
}
}
}
-
+
private double nullAsZero(Double value) {
if (value == null) return 0;
return value;
@@ -431,7 +431,7 @@ public class IndexedTensor implements Tensor {
}
}
-
+
private final class CellIterator implements Iterator<Cell> {
private int count = 0;
@@ -451,7 +451,7 @@ public class IndexedTensor implements Tensor {
reusedCell.value = get(indexes.toSourceValueIndex());
return reusedCell;
}
-
+
}
private final class ValueIterator implements Iterator<Double> {
@@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor {
}
}
-
+
private final class SuperspaceIterator implements Iterator<SubspaceIterator> {
private final Indexes superindexes;
/** Those indexes this should iterate over */
private final List<Integer> subdimensionIndexes;
-
- /**
+
+ /**
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
- * which may be equal to or smaller than the sizes of this tensor
+ * which may be equal to or smaller than the sizes of this tensor
*/
private final DimensionSizes iterateSizes;
private int count = 0;
-
+
private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
this.iterateSizes = iterateSizes;
-
+
List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first
@@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor {
else
subdimensionIndexes.add(i);
}
-
+
superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes);
}
-
+
@Override
public boolean hasNext() {
return count < superindexes.size();
@@ -527,7 +527,7 @@ public class IndexedTensor implements Tensor {
*/
public final class SubspaceIterator implements Iterator<Tensor.Cell> {
- /**
+ /**
* This iterator will iterate over the given dimensions, in the order given
* (the first dimension index given is incremented to exhaustion first (i.e is etc.).
* This may be any subset of the dimensions given by address and dimensionSizes.
@@ -538,21 +538,21 @@ public class IndexedTensor implements Tensor {
private Indexes indexes;
private int count = 0;
-
+
/** A lazy cell for reuse */
private final LazyCell reusedCell;
-
- /**
+
+ /**
* Creates a new subspace iterator
- *
+ *
* @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the
* type of the tensor this iterates over. This iterator will iterate over these
- * dimensions to exhaustion in the order given (the first dimension index given is
+ * dimensions to exhaustion in the order given (the first dimension index given is
* incremented to exhaustion first (i.e is etc.), while other dimensions will be held
* at a constant position.
* This may be any subset of the dimensions given by address and dimensionSizes.
* This is treated as immutable.
- * @param address the address of the first cell of this subspace.
+ * @param address the address of the first cell of this subspace.
*/
private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
@@ -561,26 +561,26 @@ public class IndexedTensor implements Tensor {
this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
reusedCell = new LazyCell(indexes, Double.NaN);
}
-
+
/** Returns the total number of cells in this subspace */
- public int size() {
+ public int size() {
return indexes.size();
}
-
+
/** Returns the address of the cell this currently points to (which may be an invalid position) */
public TensorAddress address() { return indexes.toAddress(); }
-
+
/** Rewind this iterator to the first element */
- public void reset() {
+ public void reset() {
this.count = 0;
- this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
}
-
+
@Override
public boolean hasNext() {
- return count < indexes.size();
+ return count < indexes.size();
}
-
+
/** Returns the next cell, which is valid until next() is called again */
@Override
public Cell next() {
@@ -611,15 +611,15 @@ public class IndexedTensor implements Tensor {
public TensorAddress getKey() {
return indexes.toAddress();
}
-
+
@Override
public Double getValue() { return value; }
}
// TODO: Make dimensionSizes a class
-
- /**
+
+ /**
* An array of indexes into this tensor which are able to find the next index in the value order.
* next() can be called once per element in the dimensions we iterate over. It must be called once
* before accessing the first position.
@@ -631,7 +631,7 @@ public class IndexedTensor implements Tensor {
private final DimensionSizes iterationSizes;
protected final int[] indexes;
-
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -676,14 +676,14 @@ public class IndexedTensor implements Tensor {
return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size);
}
}
-
+
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
iterationDimensions.add(length - 1 - i);
return iterationDimensions;
}
-
+
private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
this.sourceSizes = sourceSizes;
this.iterationSizes = iterationSizes;
@@ -708,9 +708,9 @@ public class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public int[] indexesForReading() { return indexes; }
-
- int toSourceValueIndex() {
- return IndexedTensor.toValueIndex(indexes, sourceSizes);
+
+ int toSourceValueIndex() {
+ return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
@@ -729,9 +729,9 @@ public class IndexedTensor implements Tensor {
public String toString() {
return "indexes " + Arrays.toString(indexes);
}
-
+
public abstract int size();
-
+
public abstract void next();
}
@@ -763,18 +763,18 @@ public class IndexedTensor implements Tensor {
public void next() {}
}
-
+
private static class MultiDimensionIndexes extends Indexes {
private final int size;
private final List<Integer> iterateDimensions;
-
+
private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
-
+
// Initialize to the (virtual) position before the first cell
indexes[iterateDimensions.get(0)]--;
}
@@ -785,10 +785,10 @@ public class IndexedTensor implements Tensor {
return size;
}
- /**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
- *
+ /**
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
+ *
* @throws RuntimeException if this is called more times than its size
*/
@Override
@@ -802,12 +802,12 @@ public class IndexedTensor implements Tensor {
}
}
-
+
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
private int lastComputedSourceValueIndex = -1;
-
+
private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
@@ -827,7 +827,7 @@ public class IndexedTensor implements Tensor {
private final int size;
private final int iterateDimension;
-
+
/** Maintain this directly as an optimization for 1-d iteration */
private int currentSourceValueIndex, currentIterationValueIndex;
@@ -847,7 +847,7 @@ public class IndexedTensor implements Tensor {
currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes);
currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes);
}
-
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
public int size() {
@@ -855,8 +855,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
@@ -888,7 +888,7 @@ public class IndexedTensor implements Tensor {
/** The iteration step in the value index space */
private final int step;
- private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
+ private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
int iterateDimension, int[] initialIndexes, int size) {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
@@ -907,8 +907,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 618bff0caae..aba61478e69 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -27,7 +27,7 @@ public class MappedTensor implements Tensor {
@Override
public TensorType type() { return type; }
-
+
@Override
public int size() { return cells.size(); }
@@ -56,16 +56,16 @@ public class MappedTensor implements Tensor {
}
public static class Builder implements Tensor.Builder {
-
+
private final TensorType type;
private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
+
public static Builder of(TensorType type) { return new Builder(type); }
private Builder(TensorType type) {
this.type = type;
}
-
+
public CellBuilder cell() {
return new CellBuilder(type, this);
}
@@ -89,24 +89,24 @@ public class MappedTensor implements Tensor {
public MappedTensor build() {
return new MappedTensor(type, cells.build());
}
-
+
}
private static class CellIteratorAdaptor implements Iterator<Cell> {
private final Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator;
-
+
private CellIteratorAdaptor(Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator) {
this.adaptedIterator = adaptedIterator;
}
-
+
@Override
public boolean hasNext() { return adaptedIterator.hasNext(); }
@Override
public Cell next() {
Map.Entry<TensorAddress, Double> entry = adaptedIterator.next();
- return new Cell(entry.getKey(), entry.getValue());
+ return new Cell(entry.getKey(), entry.getValue());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 79bb27fcd1b..9a751e078e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -117,7 +117,7 @@ public class MixedTensor implements Tensor {
return index.denseSubspaceSize();
}
-
+
/**
* Base class for building mixed tensors.
*/
@@ -286,7 +286,7 @@ public class MixedTensor implements Tensor {
}
return typeBuilder.build();
}
-
+
}
/**
@@ -360,7 +360,7 @@ public class MixedTensor implements Tensor {
}
return denseSubspaceSize;
}
-
+
private TensorAddress sparsePartialAddress(TensorAddress address) {
if (type.dimensions().size() != address.size()) {
throw new IllegalArgumentException("Tensor type and address are not of same size.");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 2ed211539d8..1b60e01cf7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -52,7 +52,7 @@ import java.util.function.Function;
public interface Tensor {
// ----------------- Accessors
-
+
TensorType type();
/** Returns whether this have any cells */
@@ -70,13 +70,13 @@ public interface Tensor {
/** Returns the values of this in some undefined order */
Iterator<Double> valueIterator();
- /**
+ /**
* Returns an immutable map of the cells of this in no particular order.
- * This may be expensive for some implementations - avoid when possible
+ * This may be expensive for some implementations - avoid when possible
*/
Map<TensorAddress, Double> cells();
- /**
+ /**
* Returns the value of this as a double if it has no dimensions and one value
*
* @throws IllegalStateException if this does not have zero dimensions and one value
@@ -87,9 +87,9 @@ public interface Tensor {
if (size() == 0) return Double.NaN;
return valueIterator().next();
}
-
+
// ----------------- Primitive tensor functions
-
+
default Tensor map(DoubleUnaryOperator mapper) {
return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
}
@@ -108,7 +108,7 @@ public interface Tensor {
}
default Tensor rename(String fromDimension, String toDimension) {
- return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
+ return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
Collections.singletonList(toDimension)).evaluate();
}
@@ -123,13 +123,13 @@ public interface Tensor {
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
-
+
static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) {
return new Generate(type, valueSupplier).evaluate();
}
-
+
// ----------------- Composite tensor functions which have a defined primitive mapping
-
+
default Tensor l1Normalize(String dimension) {
return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
}
@@ -231,7 +231,7 @@ public interface Tensor {
if (cellEntries.isEmpty()) return "{}";
return "{" + cellEntries.get(0).getValue() +"}";
}
-
+
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
@@ -253,7 +253,7 @@ public interface Tensor {
*/
boolean equals(Object o);
- /**
+ /**
* Implement here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
*/
@@ -328,13 +328,13 @@ public interface Tensor {
@Override
public TensorAddress getKey() { return address; }
- /**
+ /**
* Returns the direct index which can be used to locate this cell, or -1 if not available.
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
int getDirectIndex() { return -1; }
-
+
@Override
public Double getValue() { return value; }
@@ -388,20 +388,20 @@ public interface Tensor {
/** Returns the type this is building */
TensorType type();
-
+
/** Return a cell builder */
CellBuilder cell();
/** Add a cell */
Builder cell(TensorAddress address, double value);
-
+
/** Add a cell */
Builder cell(double value, int ... labels);
- /**
- * Add a cell
- *
- * @param cell a cell providing the location at which to add this cell
+ /**
+ * Add a cell
+ *
+ * @param cell a cell providing the location at which to add this cell
* @param value the value to assign to the cell
*/
default Builder cell(Cell cell, double value) {
@@ -409,12 +409,12 @@ public interface Tensor {
}
Tensor build();
-
+
class CellBuilder {
private final TensorAddress.Builder addressBuilder;
private final Tensor.Builder tensorBuilder;
-
+
CellBuilder(TensorType type, Tensor.Builder tensorBuilder) {
addressBuilder = new TensorAddress.Builder(type);
this.tensorBuilder = tensorBuilder;
@@ -436,5 +436,5 @@ public interface Tensor {
}
}
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 7161450d5d5..ff1202463f2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -32,10 +32,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
/** Returns the number of labels in this */
public abstract int size();
-
+
/**
- * Returns the i'th label in this
- *
+ * Returns the i'th label in this
+ *
* @throws IllegalArgumentException if there is no label at this index
*/
public abstract String label(int i);
@@ -102,23 +102,23 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private StringTensorAddress(String ... labels) {
this.labels = Arrays.copyOf(labels, labels.length);
}
-
+
@Override
public int size() { return labels.length; }
-
+
@Override
public String label(int i) { return labels[i]; }
-
+
@Override
- public int intLabel(int i) {
+ public int intLabel(int i) {
try {
return Integer.parseInt(labels[i]);
- }
+ }
catch (NumberFormatException e) {
throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i);
}
}
-
+
@Override
public TensorAddress withLabel(int index, int label) {
String[] labels = Arrays.copyOf(this.labels, this.labels.length);
@@ -169,7 +169,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private final TensorType type;
private final String[] labels;
-
+
public Builder(TensorType type) {
this(type, new String[type.dimensions().size()]);
}
@@ -193,7 +193,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
labels[labelIndex.get()] = label;
return this;
}
-
+
/** Creates a copy of this which can be modified separately */
public Builder copy() {
return new Builder(type, Arrays.copyOf(labels, labels.length));
@@ -202,7 +202,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public TensorAddress build() {
for (int i = 0; i < labels.length; i++)
if (labels[i] == null)
- throw new IllegalArgumentException("Missing a value for dimension " +
+ throw new IllegalArgumentException("Missing a value for dimension " +
type.dimensions().get(i).name() + " for " + type);
return TensorAddress.of(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index da8ab3bb0ec..9b3a9328f07 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -96,7 +96,7 @@ class TensorParser {
if (valueEnd < 0)
throw new IllegalArgumentException("A tensor string must end by '}'");
}
-
+
TensorAddress address = addressBuilder.build();
Double value = asDouble(address, s.substring(0, valueEnd).trim());
builder.cell(address, value);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index c05c35d6df3..914d853aeca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,14 +53,17 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the number of dimensions of this: dimensions().size() */
+ public int rank() { return dimensions.size(); }
+
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
-
+
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
}
-
+
/** Returns the dimension with this name, or empty if not present */
public Optional<Dimension> dimension(String name) {
return indexOfDimension(name).map(i -> dimensions.get(i));
@@ -74,7 +77,7 @@ public class TensorType {
return Optional.empty();
}
- /**
+ /**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
*/
@@ -128,9 +131,9 @@ public class TensorType {
private final String name;
- private Dimension(String name) {
+ private Dimension(String name) {
Objects.requireNonNull(name, "A tensor name cannot be null");
- this.name = name;
+ this.name = name;
}
public final String name() { return name; }
@@ -146,7 +149,7 @@ public class TensorType {
/** Returns true if this is an indexed bound or unboun type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
- /**
+ /**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
* types. This works by degrading to the type making the fewer promises.
* [N] + [M] = [min(N, M)]
@@ -165,7 +168,7 @@ public class TensorType {
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
}
-
+
@Override
public abstract String toString();
@@ -175,21 +178,21 @@ public class TensorType {
if (other == null || getClass() != other.getClass()) return false;
return name.equals(((Dimension)other).name);
}
-
+
@Override
public int hashCode() {
return name.hashCode();
}
-
+
@Override
public int compareTo(Dimension other) {
return this.name.compareTo(other.name);
}
-
+
public static Dimension indexed(String name, int size) {
return new IndexedBoundDimension(name, size);
}
-
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -289,9 +292,9 @@ public class TensorType {
public Builder() {
}
- /**
- * Creates a builder containing a combination of the dimensions of the given types
- *
+ /**
+ * Creates a builder containing a combination of the dimensions of the given types
+ *
* If the same dimension is indexed with different size restrictions the largest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
@@ -325,9 +328,12 @@ public class TensorType {
}
}
- /**
+ /** Returns the current number of dimensions in this */
+ public int rank() { return dimensions.size(); }
+
+ /**
* Adds a new dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
private Builder add(Dimension dimension) {
@@ -346,7 +352,7 @@ public class TensorType {
return this;
}
- /**
+ /**
* Adds a bound indexed dimension to this
*
* @throws IllegalArgumentException if the dimension is already present
@@ -355,7 +361,7 @@ public class TensorType {
/**
* Adds an unbound indexed dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
public Builder indexed(String name) {
@@ -375,7 +381,7 @@ public class TensorType {
public Builder dimension(Dimension dimension) {
return add(dimension);
}
-
+
/** Returns the given dimension, or empty if none is present */
public Optional<Dimension> getDimension(String dimension) {
return Optional.ofNullable(dimensions.get(dimension));
@@ -393,7 +399,7 @@ public class TensorType {
public TensorType build() {
return new TensorType(dimensions.values());
}
-
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 84caca78fb2..3db661f8a23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -2,16 +2,17 @@
package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
-
-import java.util.HashMap;
+import com.yahoo.tensor.Tensor;
/**
* An evaluation context which is passed down to all nested functions during evaluation.
- * The default context is empty to allow various evaluation frameworks to support their own implementation.
- *
+ *
* @author bratseth
*/
@Beta
public interface EvaluationContext {
+ /** Returns the tensor bound to this name, or null if none */
+ Tensor getTensor(String name);
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index cf704c15f4f..db8a66a5fa2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext {
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
- /** Returns the tensor bound to this name, or null if none */
- public Tensor get(String name) { return bindings.get(name); }
+ @Override
+ public Tensor getTensor(String name) { return bindings.get(name); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 8ade181bdb7..1f6ad050368 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -12,18 +12,18 @@ import java.util.List;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
- *
+ *
* @author bratseth
*/
@Beta
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
-
+
public VariableTensor(String name) {
this.name = name;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
@Override
public Tensor evaluate(EvaluationContext context) {
- return ((MapEvaluationContext)context).get(name);
+ return context.getTensor(name);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 8f4dbf014a7..191c7988443 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
- *
+ *
* @author bratseth
*/
@Beta
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 1dbb94fdb20..faa0ca36cb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -15,7 +15,7 @@ import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
- *
+ *
* @author bratseth
*/
@Beta
@@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction {
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
-
+
private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
@@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction {
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
return tensor.multiply(unitTensor);
}
-
+
}
/** Returns the type resulting from concatenating a and b */
@@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Combine two addresses, adding the offset to the concat dimension
*
- * @return the combined address or null if the addresses are incompatible
+ * @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
@@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 4ac7b21ba90..14ed38718ce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -10,18 +10,18 @@ import java.util.List;
/**
* A function which returns a constant tensor.
- *
+ *
* @author bratseth
*/
@Beta
public class ConstantTensor extends PrimitiveTensorFunction {
private final Tensor constant;
-
+
public ConstantTensor(String tensorString) {
this.constant = Tensor.from(tensorString);
}
-
+
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index bbdbd5c3df1..c75d8ee4753 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -11,19 +11,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
- *
+ *
* @author bratseth
*/
public class Diag extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> diagFunction;
-
+
public Diag(TensorType type) {
this.type = type;
this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 6ea73b7f310..e42d25197e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -15,7 +15,7 @@ import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
- *
+ *
* @author bratseth
*/
@Beta
@@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction {
/**
* Creates a generated tensor
- *
+ *
* @param type the type of the tensor
* @param generator the function generating values from a list of ints specifying the indexes of the
* tensor cell which will receive the value
@@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction {
this.type = type;
this.generator = generator;
}
-
+
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
@@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
-
+
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
@@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
-
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 8c4dbfb0acb..ff887e3e9a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
+ /** Returns the type resulting from applying Join to the two given types */
+ public static TensorType outputType(TensorType a, TensorType b) {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < a.dimensions().size(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ for (int j = 0; j < b.dimensions().size(); ++j) {
+ TensorType.Dimension bDim = b.dimensions().get(j);
+ if (aDim.name().equals(bDim.name())) { // include
+ if (aDim.isIndexed() && bDim.isIndexed()) {
+ if (aDim.size().isPresent() || bDim.size().isPresent())
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ bDim.size().orElse(Integer.MAX_VALUE)));
+ else
+ typeBuilder.indexed(aDim.name());
+ }
+ else {
+ typeBuilder.mapped(aDim.name());
+ }
+ }
+ }
+ }
+ return typeBuilder.build();
+ }
+
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@@ -88,11 +112,11 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
@@ -114,7 +138,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -126,7 +150,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -134,14 +158,14 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
@@ -200,7 +224,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -235,7 +259,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -252,7 +276,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
@@ -260,7 +284,7 @@ public class Join extends PrimitiveTensorFunction {
builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -271,7 +295,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -340,7 +364,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -360,7 +384,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a9872bb42d8..a5e1a016a41 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -32,6 +33,8 @@ public class Map extends PrimitiveTensorFunction {
this.mapper = mapper;
}
+ public static TensorType outputType(TensorType inputType) { return inputType; }
+
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index bb27e937699..4071917c2b5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.List;
@@ -14,13 +15,17 @@ public class Matmul extends CompositeTensorFunction {
private final TensorFunction argument1, argument2;
private final String dimension;
-
+
public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
}
+ public static TensorType outputType(TensorType a, TensorType b, String dimension) {
+ return Join.outputType(a, b);
+ }
+
@Override
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
@@ -39,7 +44,7 @@ public class Matmul extends CompositeTensorFunction {
Reduce.Aggregator.sum,
dimension);
}
-
+
@Override
public String toString(ToStringContext context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index efb7b9e500c..b7c9a5d2342 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor;
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
* Primitive tensor functions are fully inspectable.
- *
+ *
* @author bratseth
*/
@Beta
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 457763e97ba..958ef85d1dc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -22,11 +22,11 @@ import java.util.stream.Stream;
public class Random extends CompositeTensorFunction {
private final TensorType type;
-
+
public Random(TensorType type) {
this.type = type;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index e2b39a2048d..a56f82b026a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -12,19 +12,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
* indexes of each position.
- *
+ *
* @author bratseth
*/
public class Range extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> rangeFunction;
-
+
public Range(TensorType type) {
this.type = type;
this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index cfc78be7e0c..de9f90a5804 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -19,7 +19,7 @@ import java.util.Objects;
import java.util.Set;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
@@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction {
/**
* Creates a reduce function.
- *
+ *
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
@@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
+ public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name()))
+ b.dimension(dimension);
+ }
+ return b.build();
+ }
+
public TensorFunction argument() { return argument; }
@Override
@@ -82,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction {
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
-
+
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
@@ -94,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
@@ -103,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction {
return reduceIndexedVector((IndexedTensor)argument);
else
return reduceAllGeneral(argument);
-
+
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argument.type().dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
TensorType reducedType = builder.build();
-
+
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
@@ -122,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction {
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
-
+
return reducedBuilder.build();
}
-
+
private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
Set<Integer> indexesToRemove = new HashSet<>();
for (String dimensionToRemove : this.dimensions)
@@ -138,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction {
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
-
+
private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
@@ -154,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static abstract class ValueAggregator {
-
+
private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
@@ -165,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction {
case min : return new MinAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
-
+
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
-
+
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
-
+
}
-
+
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
-
+
@Override
public void aggregate(double value) {
valueCount++;
@@ -188,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public double aggregatedValue() {
+ public double aggregatedValue() {
return valueSum / valueCount;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6b0daf1b49a..ec9b762a41c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -19,7 +17,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -29,6 +27,10 @@ public class Rename extends PrimitiveTensorFunction {
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
+ }
+
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
@@ -42,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -62,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -70,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -86,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -95,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 99f79cb735a..fb5029fbfd6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -21,101 +21,87 @@ import java.util.stream.Collectors;
@Beta
public class ScalarFunctions {
- public static DoubleBinaryOperator add() { return new Addition(); }
- public static DoubleBinaryOperator multiply() { return new Multiplication(); }
- public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleBinaryOperator add() { return new Add(); }
+ public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleBinaryOperator multiply() { return new Multiply(); }
+
+ public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
- public static DoubleUnaryOperator exp() { return new Exponent(); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+
public static Function<List<Integer>, Double> random() { return new Random(); }
public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
- public static class Addition implements DoubleBinaryOperator {
+ // Binary operators -----------------------------------------------------------------------------
+ public static class Add implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left + right; }
-
@Override
public String toString() { return "f(a,b)(a + b)"; }
-
}
- public static class Multiplication implements DoubleBinaryOperator {
-
+ public static class Equal implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
-
+ public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
@Override
- public String toString() { return "f(a,b)(a * b)"; }
-
+ public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Division implements DoubleBinaryOperator {
-
+ public static class Exp implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left / right; }
-
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
- public String toString() { return "f(a,b)(a / b)"; }
+ public String toString() { return "f(a)(exp(a))"; }
}
- public static class Equal implements DoubleBinaryOperator {
-
+ public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
+ public double applyAsDouble(double left, double right) { return left * right; }
+ @Override
+ public String toString() { return "f(a,b)(a * b)"; }
+ }
+ public static class Divide implements DoubleBinaryOperator {
@Override
- public String toString() { return "f(a,b)(a==b)"; }
+ public double applyAsDouble(double left, double right) { return left / right; }
+ @Override
+ public String toString() { return "f(a,b)(a / b)"; }
}
- public static class Square implements DoubleUnaryOperator {
+ // Unary operators ------------------------------------------------------------------------------
+ public static class Acos implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return operand * operand; }
-
+ public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
- public String toString() { return "f(a)(a * a)"; }
-
+ public String toString() { return "f(a)(acos(a))"; }
}
public static class Sqrt implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
-
@Override
public String toString() { return "f(a)(sqrt(a))"; }
-
}
- public static class Exponent implements DoubleUnaryOperator {
+ public static class Square implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
+ public double applyAsDouble(double operand) { return operand * operand; }
@Override
- public String toString() { return "f(a)(exp(a))"; }
+ public String toString() { return "f(a)(a * a)"; }
}
- public static class Random implements Function<List<Integer>, Double> {
-
- @Override
- public Double apply(List<Integer> values) {
- return ThreadLocalRandom.current().nextDouble();
- }
-
- @Override
- public String toString() { return "random"; }
+ // Variable-length operators -----------------------------------------------------------------------------
- }
-
- public static class EqualElements implements Function<List<Integer>, Double> {
-
- private final ImmutableList<String> argumentNames;
-
+ public static class EqualElements implements Function<List<Integer>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -128,7 +114,6 @@ public class ScalarFunctions {
return 0.0;
return 1.0;
}
-
@Override
public String toString() {
if (argumentNames.size() == 0) return "1";
@@ -143,13 +128,19 @@ public class ScalarFunctions {
}
return b.toString();
}
+ }
+ public static class Random implements Function<List<Integer>, Double> {
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+ @Override
+ public String toString() { return "random"; }
}
public static class SumElements implements Function<List<Integer>, Double> {
-
private final ImmutableList<String> argumentNames;
-
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -161,12 +152,10 @@ public class ScalarFunctions {
sum += value;
return (double)sum;
}
-
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bf279eb24d8..c856b548180 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -2,6 +2,8 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.List;
@@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
+
+ public static TensorType outputType(TensorType inputType, String dimension) {
+ return Reduce.outputType(inputType, ImmutableList.of(dimension));
+ }
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index cabcce198d1..533a46f87fe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -12,7 +12,7 @@ import java.util.List;
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
* All tensor functions are immutable.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,11 +48,11 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
- *
+ *
* @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
-
+
@Override
public String toString() { return toString(ToStringContext.empty()); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index e8c425d49e0..416b28afa22 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -24,7 +24,7 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
- *
+ *
* @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data
* @param buffer the buffer containing the tensor binary data
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 8b7325ec211..aabb53d1c67 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -16,9 +16,9 @@ import java.util.Optional;
*
* Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
* Cell_values = [double, double, double, ...]*
- * where values are encoded in order of increasing indexes in each dimension, increasing
+ * where values are encoded in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
- *
+ *
* @author bratseth
*/
@Beta
@@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat {
type = optionalType.get();
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
- throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
" cannot be assigned to type " + type);
sizes = sizesFromType(serializedType);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 7467554790a..01a1d023f2b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -46,9 +46,9 @@ public class TypedBinaryFormat {
return result;
}
- /**
- * Decode some data to a tensor
- *
+ /**
+ * Decode some data to a tensor
+ *
* @param type the type to decode and validate to, or empty to use the type given in the data
* @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array
* @return the resulting tensor
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index d199dd3a876..abdb3071bf7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -13,14 +13,14 @@ import java.util.stream.Collectors;
/**
* Microbenchmark of tensor operations.
- *
+ *
* @author bratseth
*/
public class TensorFunctionBenchmark {
private final static Random random = new Random();
-
- public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
+
+ public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
boolean extraSpace) {
Tensor queryVector = vectors(1, 300, dimensionType).get(0);
if (extraSpace) {
@@ -34,7 +34,7 @@ public class TensorFunctionBenchmark {
long totalTime = System.currentTimeMillis() - startTime;
return (double)totalTime / (double)iterations;
}
-
+
private Tensor unitVector(String dimension) {
return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build())
.cell().label(dimension, 0).value(1).build();
@@ -49,11 +49,11 @@ public class TensorFunctionBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
+ TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
+ new VariableTensor("argument"), (a, b) -> a * b),
Reduce.Aggregator.sum).toPrimitive();
MapEvaluationContext context = new MapEvaluationContext();
-
+
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
double dotProduct = dotProductFunction.evaluate(context).asDouble();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 30078b4a826..693b0f09351 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -25,7 +25,7 @@ import static org.junit.Assert.fail;
/**
* Tests tensor functionality
- *
+ *
* @author bratseth
*/
public class TensorTestCase {
@@ -108,7 +108,7 @@ public class TensorTestCase {
Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x"));
}
-
+
/** Test the same computation made in various ways which are implemented with special-case optimizations */
@Test
public void testOptimizedComputation() {
@@ -130,7 +130,7 @@ public class TensorTestCase {
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
-
+
// Test the unoptimized path by joining in another dimension
Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build();
Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build();
@@ -138,7 +138,7 @@ public class TensorTestCase {
Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK);
assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace)));
}
-
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
@@ -161,7 +161,7 @@ public class TensorTestCase {
private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) {
return vectors(vectorSize, dimensionType, 1).get(0);
}
-
+
/** Create a list of vectors having a single dimension x */
private List<Tensor> vectors(TensorType.Dimension.Type dimensionType, int vectorCount) {
return vectors(3, dimensionType, vectorCount);
@@ -179,8 +179,8 @@ public class TensorTestCase {
}
return tensors;
}
-
- /**
+
+ /**
* Create a matrix of vectors (in dimension i) where each vector has the dimension x.
* This matrix contains the same vectors as returned by createVectors, in a single list element for convenience.
*/
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
index fab53218b2c..f11c068bd74 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
@@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals;
* @author bratseth
*/
public class JoinTestCase {
-
+
/** Test the indexed subspace join optimization */
@Test
public void testJoinIndexedSubspace() {
Tensor t1, t2;
-
+
t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}");
t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}");
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"),
@@ -34,10 +34,10 @@ public class JoinTestCase {
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"),
t2.divide(t1));
}
-
+
@Test
public void testGeneralJoin() {
- assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
+ assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }")
.divide(Tensor.from("tensor(y[]):{{y:0}:2}")));
@@ -45,5 +45,5 @@ public class JoinTestCase {
Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }")
.divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }")));
}
-
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
new file mode 100644
index 00000000000..9643c0a56e7
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -0,0 +1,97 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class MatmulTestCase {
+
+ @Test
+ public void testMatmul2d() {
+ // d0 is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
+ ab.cell( 1,0, 0);
+ ab.cell( 2,0, 1);
+ ab.cell( 3,0, 2);
+ ab.cell( 4,1, 0);
+ ab.cell( 5,1, 1);
+ ab.cell( 6,1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])"));
+ bb.cell( 7,0, 0);
+ bb.cell( 8,0, 1);
+ bb.cell( 9,1, 0);
+ bb.cell(10,1, 1);
+ bb.cell(11,2, 0);
+ bb.cell(12,2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])"));
+ rb.cell( 58,0, 0);
+ rb.cell( 64,0, 1);
+ rb.cell(139,1, 0);
+ rb.cell(154,1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1")
+ .rename("d2","d1");
+ assertEquals(r, result);
+ }
+
+ @Test
+ public void testMatmul3d() {
+ // Convention: a is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])"));
+ ab.cell( 1,0, 0, 0);
+ ab.cell( 2,0, 0, 1);
+ ab.cell( 3,0, 0, 2);
+ ab.cell( 4,0, 1, 0);
+ ab.cell( 5,0, 1, 1);
+ ab.cell( 6,0, 1, 2);
+ ab.cell( 7,1, 0, 0);
+ ab.cell( 8,1, 0, 1);
+ ab.cell( 9,1, 0, 2);
+ ab.cell(10,1, 1, 0);
+ ab.cell(11,1, 1, 1);
+ ab.cell(12,1, 1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])"));
+ bb.cell(13,0, 0, 0);
+ bb.cell(14,0, 0, 1);
+ bb.cell(15,0, 1, 0);
+ bb.cell(16,0, 1, 1);
+ bb.cell(17,0, 2, 0);
+ bb.cell(18,0, 2, 1);
+ bb.cell(19,1, 0, 0);
+ bb.cell(20,1, 0, 1);
+ bb.cell(21,1, 1, 0);
+ bb.cell(22,1, 1, 1);
+ bb.cell(23,1, 2, 0);
+ bb.cell(24,1, 2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])"));
+ rb.cell( 94,0, 0, 0);
+ rb.cell(100,0, 0, 1);
+ rb.cell(229,0, 1, 0);
+ rb.cell(244,0, 1, 1);
+ rb.cell(508,1, 0, 0);
+ rb.cell(532,1, 0, 1);
+ rb.cell(697,1, 1, 0);
+ rb.cell(730,1, 1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2")
+ .rename("d3","d2");
+ assertEquals(r, result);
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index 8a58cb0bbed..55069eaced7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals;
/**
* Tests translation of composite to primitive tensor function translation.
- *
+ *
* @author bratseth
*/
public class TensorFunctionTestCase {
@@ -16,12 +16,12 @@ public class TensorFunctionTestCase {
public void testTranslation() {
assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))",
new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
- assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
+ assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}
-
+
private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
assertEquals(expectedTranslation, inputFunction.toPrimitive().toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 349309a5052..15a872e439f 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase {
assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
}
-
+
@Test
public void testSerializationToSeparateType() {
assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])"));
@@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index b1d7d797b3e..33dfca017f4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
index 68bf59e3ed9..f002637847b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
@@ -50,7 +50,7 @@ public class SerializationTestCase {
JsonNode node = mapper.readTree(test);
if (node.has("tensor") && node.has("binary")) {
System.out.println("Running test: " + test);
-
+
Tensor tensor = buildTensor(node.get("tensor"));
String spec = getSpec(node.get("tensor"));
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
@@ -123,7 +123,7 @@ public class SerializationTestCase {
private byte[] getBytes(String binaryRepresentation) {
return parseHexValue(binaryRepresentation.substring(2));
}
-
+
private byte[] parseHexValue(String s) {
final int len = s.length();
byte[] bytes = new byte[len/2];
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index d17148cf8dc..f895b64379b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}