diff options
author | gjoranv <gjoranv@gmail.com> | 2017-12-17 21:44:49 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-17 21:44:49 +0100 |
commit | 03bce1fe1a494f2ac9d4268d4c90b08011b3f600 (patch) | |
tree | 180f294d2ac97d641f0266216ffdc328db9bfef8 | |
parent | b72e55b87eecae006ed92976151137a80d75be0f (diff) |
Revert "Bratseth/tensorflow models"
91 files changed, 470 insertions, 7355 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java index 61cab2f6ce7..cee501841b4 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java @@ -4,9 +4,10 @@ package com.yahoo.config.application.api; import java.util.logging.Level; /** - * Used during application deployment to propagate messages to the end user + * Used during application deployment to persist and propagate messages to end user * - * @author Ulf Lillengen + * @author lulf + * @since 5.1 */ public interface DeployLogger { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java index 65176006a2a..69e353ceb35 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java @@ -118,7 +118,7 @@ public class TensorTransformer extends ExpressionTransformer { private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java index f932265cb93..c52a5dc465d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java @@ -259,7 +259,7 @@ public final class Attribute implements Cloneable, Serializable { throw new IllegalArgumentException("Field " + fieldType + " not supported in convertCollectionType"); } } - + private static Optional<TensorType> convertTensorType(DataType fieldType) { if ( ! ( fieldType instanceof TensorDataType)) return Optional.empty(); return Optional.of(((TensorDataType)fieldType).getTensorType()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java index 8b6df1a87db..c8918f39834 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java @@ -29,7 +29,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public <T extends Expression> boolean containsExpression(Class<T> searchFor) { - throw createUnsupportedException(searchFor.getSimpleName()); + throw createUnsupportedException(); } @Override @@ -79,9 +79,9 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Index getIndex(String name) { - if ( ! importedField.fieldName().equals(name)) { + if (!importedField.fieldName().equals(name)) { throw new IllegalArgumentException("Getting an index (" + name + ") with different name than the imported field (" - + importedField.fieldName() + ") is not supported"); + + importedField.fieldName() + ") is not supported"); } String targetIndexName = importedField.targetField().getName(); return importedField.targetField().getIndex(targetIndexName); @@ -104,7 +104,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ScriptExpression getIndexingScript() { - throw createUnsupportedException("indexing"); + throw createUnsupportedException(); } @Override @@ -119,12 +119,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ImmutableSDField getStructField(String name) { - throw createUnsupportedException("struct"); + throw createUnsupportedException(); } @Override public Collection<? extends ImmutableSDField> getStructFields() { - throw createUnsupportedException("struct"); + throw createUnsupportedException(); } @Override @@ -134,12 +134,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Stemming getStemming(Search search) { - throw createUnsupportedException("stemming"); + throw createUnsupportedException(); } @Override public Ranking getRanking() { - throw createUnsupportedException("ranking"); + throw createUnsupportedException(); } @Override @@ -158,8 +158,8 @@ public class ImmutableImportedSDField implements ImmutableSDField { importedField.targetField().getDataType()); } - private static UnsupportedOperationException createUnsupportedException(String aspect) { - return new UnsupportedOperationException("'" + aspect + "' is not meaningful or relevant for an imported field."); + private static UnsupportedOperationException createUnsupportedException() { + return new UnsupportedOperationException("This aspect is not meaningful or relevant for an imported field."); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java index 9368d6aaa39..96a9448739a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java +++ b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java @@ -20,7 +20,7 @@ import java.util.Set; */ public class DocumentManager { - public DocumentmanagerConfig.Builder produce(DocumentModel model, + public DocumentmanagerConfig.Builder produce(DocumentModel model, DocumentmanagerConfig.Builder documentConfigBuilder) { documentConfigBuilder.enablecompression(false); Set<DataType> handled = new HashSet<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index ce3c04f41f7..6eeb12ffdd9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -45,7 +45,7 @@ public class ConstantTensorJsonValidator { throw new IllegalArgumentException("Ranking constant file names must end with either '.json' or '.json.lz4'"); } } - + private void validateTensor(TensorType type, Reader tensorData) { wrapIOException(() -> { this.parser = jsonFactory.createParser(tensorData); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index c686f023d5b..4a9310799aa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -64,7 +64,7 @@ public class RankingConstantsValidator extends Validator { private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage applicationPackage) throws FileNotFoundException { ApplicationFile tensorApplicationFile = applicationPackage.getFile(Path.fromString(rankingConstant.getFileName())); - new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), + new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), rankingConstant.getTensorType(), tensorApplicationFile.createReader()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java index e1675007bbc..a5b7d67e377 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java @@ -8,7 +8,6 @@ import com.yahoo.vespa.model.content.Redundancy; * Builds redundancy config for a content cluster. */ public class RedundancyBuilder { - Redundancy build(ModelElement clusterXml) { Integer initialRedundancy = 2; Integer finalRedundancy = 3; @@ -38,5 +37,4 @@ public class RedundancyBuilder { return new Redundancy(initialRedundancy, finalRedundancy, readyCopies); } - } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 960a3b7d6db..9407c21fee8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -173,7 +173,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.numeric").isPresent()); } - + private static Optional<String> findProperty(List<Pair<String, String>> properties, String key) { for (Pair<String, String> property : properties) if (property.getFirst().equals(key)) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java index 4600f6ae4c6..7cd00e155bb 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java @@ -123,7 +123,7 @@ public class ExportingTestCase extends AbstractExportingTestCase { public void testIndexinfoFieldsets() throws IOException, ParseException { assertCorrectDeriving("indexinfo_fieldsets"); } - + @Test public void testStreamingJuniper() throws IOException, ParseException { assertCorrectDeriving("streamingjuniper"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index e5693d24f0f..12bdd8d2b5c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -202,5 +202,5 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { } return b.toString(); } - + } diff --git a/container-dev/pom.xml b/container-dev/pom.xml index 16006452e61..f62bbd22690 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -121,18 +121,6 @@ <groupId>org.bouncycastle</groupId> <artifactId>bcprov-jdk15on</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> @@ -201,18 +189,6 @@ <groupId>xerces</groupId> <artifactId>xercesImpl</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java index 1e44a8fa64d..fc1bbace092 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java @@ -1,16 +1,16 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.fastsearch; -import com.yahoo.container.search.LegacyEmulationConfig; -import com.yahoo.data.access.Inspector; -import com.yahoo.log.LogLevel; - import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; +import com.yahoo.data.access.Inspector; +import com.yahoo.container.search.LegacyEmulationConfig; + +import com.yahoo.log.LogLevel; /** * @author Bjørn Borud @@ -25,7 +25,7 @@ public abstract class DocsumField { Map<String, Constructor<? extends DocsumField>> constructors = new HashMap<>(); - void put(String typename, Class<? extends DocsumField> fieldClass) + void put(String typename, Class<? extends DocsumField> fieldClass) throws NoSuchMethodException, SecurityException { Constructor<? extends DocsumField> constructor = fieldClass.getConstructor(String.class); constructors.put(typename, constructor); @@ -106,7 +106,7 @@ public abstract class DocsumField { public abstract Object decode(ByteBuffer b); /** - * Get the number of bytes this field occupies in the given buffer + * Get the number of bytes this field occupies in the given buffer * AND SET(!) the position to the first byte after this field. */ public abstract int getLength(ByteBuffer b); diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java index 692e93bed7e..1524a4da426 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java @@ -109,7 +109,7 @@ public class FastHit extends Hit { /** * Returns the explicitly set uri if available, returns "index:[source]/[partid]/[id]" otherwise - * + * * @return uri of hit */ public URI getUri() { @@ -128,9 +128,9 @@ public class FastHit extends Hit { } /** - * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). + * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). * This is the uri if no other uri is assigned - * + * * @return uri to the index. */ public URI getIndexUri() { @@ -215,7 +215,7 @@ public class FastHit extends Hit { * The empty string ("") if no value is assigned in the document. * * <li><b>Dynamic summary string fields</b>: A Java String before JuniperSearcher and a HitField after.</li> - * + * * <li><b>Numerics</b>: The corresponding numeric Java type.<br> * If the field has <i>no value</i> assigned in the document, * the special numeric {@link com.yahoo.search.result.NanNumber#NaN} is returned. diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java index d8b38667224..e0ca7fbe6e1 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java @@ -13,7 +13,7 @@ import java.util.Optional; /** * A tensor field. Tensors are encoded as a data field where the data (following the length) * is encoded in a tensor binary format defined by com.yahoo.tensor.serialization.TypedBinaryFormat - * + * * @author bratseth */ public class TensorField extends DocsumField implements VariableLengthField { diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 0fd529bf262..0ec15b95b0d 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -16,7 +16,7 @@ import java.util.Optional; public class TensorFieldType extends FieldType { // TODO: Require tensor type - + private final Optional<TensorType> type; /** Creates a tensor field type with optional information about the kind of tensor this will hold */ diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java index 5494d1965f8..15a0fd60511 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java @@ -102,7 +102,7 @@ public class SlimeSummaryTestCase { public void testDecoding() { Tensor tensor1 = Tensor.from("tensor(x{},y{}):{{x:foo,y:bar}:0.1}"); Tensor tensor2 = Tensor.from("tensor(x[],y[1]):{{x:0,y:0}:-0.3}"); - + String summary_cf = "file:src/test/java/com/yahoo/prelude/fastsearch/summary.cfg"; DocsumDefinitionSet set = createDocsumDefinitionSet(summary_cf); byte[] docsum = makeDocsum(tensor1, tensor2); diff --git a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java index 62eacaa0afe..e59c03b33c3 100644 --- a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java @@ -2,6 +2,8 @@ package com.yahoo.search.test; import com.yahoo.component.chain.Chain; +import com.yahoo.language.Language; +import com.yahoo.language.Linguistics; import com.yahoo.language.detect.Detection; import com.yahoo.language.detect.Detector; import com.yahoo.language.detect.Hint; @@ -26,6 +28,7 @@ import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.result.Hit; import com.yahoo.search.searchchain.Execution; + import com.yahoo.yolean.Exceptions; import org.junit.Ignore; import org.junit.Test; @@ -42,14 +45,14 @@ import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.fail; /** @@ -66,7 +69,7 @@ public class QueryTestCase { assertEquals("", q.properties().get("aParameter")); assertNull(q.properties().get("notSetParameter")); } - + // TODO: YQL work in progress (jon) @Ignore @Test @@ -690,7 +693,7 @@ public class QueryTestCase { List<IndexedItem> l = QueryTree.getPositiveTerms(i); assertEquals(3, l.size()); } - + @Test public void testHeuristicLanguageDetectionTextExtraction() { assertDetectionText("b ", "a:b", "text:a", "text:default"); @@ -717,27 +720,27 @@ public class QueryTestCase { q.getModel().getQueryTree(); // cause parsing assertEquals(expectedDetectionText, mockLinguistics.detector.lastDetectionText); } - + /** A linguistics instance which records the last language detection text passed to it */ private static class MockLinguistics extends SimpleLinguistics { final MockDetector detector = new MockDetector(); - + @Override public Detector getDetector() { return detector; } - + } - + private static class MockDetector extends SimpleDetector { String lastDetectionText = null; - + @Override public Detection detect(String input, Hint hint) { lastDetectionText = input; return super.detect(input, hint); } - + } protected boolean contains(String lineSubstring,String[] lines) { diff --git a/container/pom.xml b/container/pom.xml index d252a5eee4a..3793a3508a4 100644 --- a/container/pom.xml +++ b/container/pom.xml @@ -47,18 +47,6 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> </dependencies> diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java index fb675862320..1b2ad9f938a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java @@ -24,7 +24,6 @@ import java.util.Objects; * @author smorgrav */ public class ClusterCost { - private final double tco; private final double waste; private final ClusterInfo clusterInfo; @@ -33,8 +32,8 @@ public class ClusterCost { private final ClusterUtilization resultUtilization; /** - * @param clusterInfo value object with cluster info e.g. the TCO for the hardware used - * @param systemUtilization utilization of system resources (as ratios) + * @param clusterInfo Value object with cluster info e.g. the TCO for the hardware used + * @param systemUtilization Utilization of system resources (as ratios) */ public ClusterCost(ClusterInfo clusterInfo, ClusterUtilization systemUtilization) { @@ -80,10 +79,10 @@ public class ClusterCost { } static ClusterUtilization calculateResultUtilization(ClusterUtilization system, ClusterUtilization target) { - double cpu = ratio(system.getCpu(), target.getCpu()); - double mem = ratio(system.getMemory(), target.getMemory()); - double disk = ratio(system.getDisk(), target.getDisk()); - double diskbusy = ratio(system.getDiskBusy(), target.getDiskBusy()); + double cpu = ratio(system.getCpu(),target.getCpu()); + double mem = ratio(system.getMemory(),target.getMemory()); + double disk = ratio(system.getDisk(),target.getDisk()); + double diskbusy = ratio(system.getDiskBusy(),target.getDiskBusy()); return new ClusterUtilization(mem, cpu, disk, diskbusy); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java index 371e1c41e32..585690793bb 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java @@ -44,17 +44,17 @@ public class DeploymentCost { return clusters; } - /** Returns the total monthly cost of ownership for the deployment (sum of all clusters) */ + /** @return Total cost of ownership for the deployment (sum of all clusters) */ public double getTco() { return tco; } - /** Returns the utilization of clusters that wastes most money in this deployment */ + /** @return The utilization of clusters that wastes most money in this deployment */ public double getUtilization() { return utilization; } - /** Returns the amount of dollars spent and not utilized */ + /** @return The amount of dollars spent and not utilized */ public double getWaste() { return waste; } diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index abdbf394591..c8a04866aa9 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -51,7 +51,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory()); public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory()); public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory()); - public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately + public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor // Tags are converted to weightedset<string> when reading the search definition TODO: Remove it @@ -99,7 +99,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com /** * Creates a field value by reflection - * + * * @param arg the value of the newly created field value * @return a fully constructed value */ @@ -201,7 +201,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public static TensorDataType getTensor(TensorType type) { return new TensorDataType(type); } - + public String getName() { return name; } @@ -267,7 +267,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com */ public FieldPath buildFieldPath(String fieldPathString) { if (fieldPathString.length() > 0) { - throw new IllegalArgumentException("Datatype " + toString() + + throw new IllegalArgumentException("Datatype " + toString() + " does not support further recursive structure: " + fieldPathString); } return new FieldPath(); diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java index 5fad35a2287..8c9318199d8 100644 --- a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java +++ b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java @@ -38,7 +38,7 @@ public class DocumentTypeManager { // *Configured data types* (not built-in/primitive) indexed by their id // // *Primitive* data types are always available and have a single id. - // + // // *Built-in dynamic* types: The tensor type. // Any tensor type has the same id and is always available just like primitive types. // However, unlike primitive types, each tensor type is a separate DataType instance @@ -112,7 +112,7 @@ public class DocumentTypeManager { public DataType getDataType(String name) { if (name.startsWith("tensor(")) // built-in dynamic return new TensorDataType(TensorType.fromSpec(name)); - + List<DataType> foundTypes = new ArrayList<>(); for (DataType type : dataTypes.values()) { if (type.getName().equalsIgnoreCase(name)) { @@ -141,10 +141,10 @@ public class DocumentTypeManager { } public DataType getDataType(int code) { return getDataType(code, ""); } - + /** * Return a data type instance - * + * * @param code the code of the data type to return, which must be either built in or present in this manager * @param detailedType detailed type information, or the empty string if none * @return the appropriate DataType instance @@ -183,7 +183,7 @@ public class DocumentTypeManager { /** * Register a single datatype. Re-registering an existing, but equal, datatype is ok. - * + * * @param type The datatype to register */ void registerSingleType(DataType type) { @@ -280,7 +280,7 @@ public class DocumentTypeManager { /** * Returns a read only view of the registered data types - * + * * @return collection of types */ public Collection<DataType> getDataTypes() { diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java index 50e9cf0f60f..aefdc030a12 100644 --- a/document/src/main/java/com/yahoo/document/TensorDataType.java +++ b/document/src/main/java/com/yahoo/document/TensorDataType.java @@ -8,13 +8,13 @@ import com.yahoo.vespa.objects.Ids; /** * A DataType containing a tensor type - * + * * @author bratseth */ public class TensorDataType extends DataType { private final TensorType tensorType; - + // The global class identifier shared with C++. public static int classId = registerClass(Ids.document + 59, TensorDataType.class); @@ -47,5 +47,5 @@ public class TensorDataType extends DataType { /** Returns the type of the tensor this field can hold */ public TensorType getTensorType() { return tensorType; } - + } diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index 1808396986e..ae8d5cf596a 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -19,7 +19,7 @@ import java.util.Optional; public class TensorFieldValue extends FieldValue { private Optional<Tensor> tensor; - + private final TensorDataType dataType; /** Create an empty tensor field value */ @@ -66,7 +66,7 @@ public class TensorFieldValue extends FieldValue { o.getClass().getName() + "'."); } } - + public void assignTensor(Optional<Tensor> tensor) { if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType())) throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 29ba244a9f1..f37fa5ea675 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -146,9 +146,9 @@ public class JsonReaderTestCase { } { DocumentType x = new DocumentType("testtensor"); - x.addField(new Field("mappedtensorfield", + x.addField(new Field("mappedtensorfield", new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); - x.addField(new Field("indexedtensorfield", + x.addField(new Field("indexedtensorfield", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); types.registerDocumentType(x); } @@ -1280,8 +1280,8 @@ public class JsonReaderTestCase { return (DocumentPut) reader.next(); } - private DocumentPut createPutWithMappedTensor(String inputTensor) { - return createPutWithTensor(inputTensor, "mappedtensorfield"); + private DocumentPut createPutWithMappedTensor(String inputTensor) { + return createPutWithTensor(inputTensor, "mappedtensorfield"); } private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) { InputStream rawDoc = new ByteArrayInputStream( diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java index 5c65b11a0c4..7104c1686f8 100644 --- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java +++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java @@ -24,7 +24,7 @@ public class TensorFieldValueSerializationTestCase { private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build(); private final static String TENSOR_FIELD = "my_tensor"; private final static String TENSOR_FILES = "src/test/resources/tensor/"; - private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), + private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), "id:test:my_type::foo"); private static DocumentType createDocType() { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java index 9ef1a3f6e32..0f08bf0bf21 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java @@ -56,7 +56,7 @@ public abstract class FieldUpdateHelper { } else if (upd instanceof ArithmeticValueUpdate) { if (((ArithmeticValueUpdate)upd).getOperator() == ArithmeticValueUpdate.Operator.DIV && ((ArithmeticValueUpdate)upd).getOperand().doubleValue() == 0) { - throw new IllegalArgumentException("Div by zero."); + throw new IllegalArgumentException("Division by zero."); } val.assign(upd.getValue()); return val; @@ -23,941 +23,6 @@ </developer> </developers> - <distributionManagement> - <repository> - <id>bintray-vespa-repo</id> - <url>https://api.bintray.com/maven/yahoo/maven/vespa;publish=1</url> - </repository> - </distributionManagement> - - <repositories> - <!-- Required for Athenz libraries --> - <repository> - <snapshots> - <enabled>false</enabled> - </snapshots> - <id>bintray-yahoo-maven</id> - <name>bintray</name> - <url>https://yahoo.bintray.com/maven</url> - </repository> - </repositories> - - <scm> - <connection>scm:git:git@github.com:vespa-engine/vespa.git</connection> - <developerConnection>scm:git:git@github.com:vespa-engine/vespa.git</developerConnection> - <url>git@github.com:vespa-engine/vespa.git</url> - </scm> - - <build> - <finalName>${project.artifactId}</finalName> - <extensions> - <extension> - <groupId>org.apache.maven.wagon</groupId> - <artifactId>wagon-ssh-external</artifactId> - <version>2.7</version> - </extension> - <extension> - <groupId>org.apache.maven.archetype</groupId> - <artifactId>archetype-packaging</artifactId> - <version>2.0</version> - </extension> - </extensions> - <pluginManagement> - <plugins> - <plugin> - <groupId>org.antlr</groupId> - <artifactId>antlr3-maven-plugin</artifactId> - <version>${antlr.version}</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-antrun-plugin</artifactId> - <version>1.7</version> - </plugin> - <plugin> - <groupId>org.apache.felix</groupId> - <artifactId>maven-bundle-plugin</artifactId> - <version>2.4.0</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-assembly-plugin</artifactId> - <version>2.4</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <version>3.6.1</version> - <configuration> - <source>1.8</source> - <target>1.8</target> - <showWarnings>true</showWarnings> - <optimize>true</optimize> - <showDeprecation>false</showDeprecation> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-serial</arg> - <arg>-Xlint:-try</arg> - <arg>-Xlint:-processing</arg> - <arg>-Xlint:-varargs</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-dependency-plugin</artifactId> - <version>2.10</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-deploy-plugin</artifactId> - <version>2.5</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-install-plugin</artifactId> - <version>2.5.2</version> - <configuration> - <updateReleaseInfo>true</updateReleaseInfo> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <version>3.0.2</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-javadoc-plugin</artifactId> - <configuration> - <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> - </configuration> - <version>2.10.4</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-plugin-plugin</artifactId> - <version>3.5</version> - <configuration> - <!-- see http://jira.codehaus.org/browse/MNG-5346 --> - <skipErrorNoDescriptorsFound>true</skipErrorNoDescriptorsFound> - </configuration> - <executions> - <execution> - <id>mojo-descriptor</id> - <goals> - <goal>descriptor</goal> - </goals> - </execution> - </executions> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-resources-plugin</artifactId> - <version>2.7</version> - <configuration> - <escapeString>\</escapeString> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-site-plugin</artifactId> - <version>3.3</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-source-plugin</artifactId> - <version>2.1.2</version> - <configuration> - <includePom>true</includePom> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-surefire-plugin</artifactId> - <version>${surefire.version}</version> - <configuration> - <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile> - <systemPropertyVariables> - <java.io.tmpdir>${project.build.directory}</java.io.tmpdir> - </systemPropertyVariables> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-surefire-report-plugin</artifactId> - <version>${surefire.version}</version> - <configuration> - <alwaysGenerateSurefireReport>false</alwaysGenerateSurefireReport> - <showSuccess>false</showSuccess> - </configuration> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>build-helper-maven-plugin</artifactId> - <version>1.9.1</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>exec-maven-plugin</artifactId> - <version>1.6.0</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>javacc-maven-plugin</artifactId> - <version>2.6</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>properties-maven-plugin</artifactId> - <version>1.0.0</version> - </plugin> - <plugin> - <groupId>net.alchim31.maven</groupId> - <artifactId>scala-maven-plugin</artifactId> - <version>3.2.2</version> - <configuration> - <args> - <arg>-unchecked</arg> - <arg>-deprecation</arg> - <arg>-feature</arg> - <arg>-Xfatal-warnings</arg> - </args> - </configuration> - </plugin> - <plugin> - <groupId>com.yahoo.vespa</groupId> - <artifactId>bundle-plugin</artifactId> - <version>${project.version}</version> - <configuration> - <configGenVersion>${project.version}</configGenVersion> - <useCommonAssemblyIds>true</useCommonAssemblyIds> - </configuration> - </plugin> - </plugins> - </pluginManagement> - </build> - <profiles> - <profile> - <id>attach-sources</id> - <activation> - <property> - <name>!skipSources</name> - </property> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-source-plugin</artifactId> - <executions> - <execution> - <id>attach-sources</id> - <goals> - <goal>jar-no-fork</goal> - </goals> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>generate-javadoc</id> - <activation> - <property> - <name>!skipJavadoc</name> - </property> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-javadoc-plugin</artifactId> - <executions> - <execution> - <id>generate-javadoc</id> - <phase>package</phase> - <goals> - <goal>javadoc</goal> - </goals> - </execution> - </executions> - <configuration> - <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> - <failOnError>${javadoc.failOnError}</failOnError> - <quiet>true</quiet> - <show>private</show> - </configuration> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>coverage</id> - <build> - <plugins> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>exec-maven-plugin</artifactId> - <configuration> - <includePluginDependencies>true</includePluginDependencies> - </configuration> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>build-helper-maven-plugin</artifactId> - <executions> - <execution> - <phase>generate-sources</phase> - <goals> - <goal>add-source</goal> - </goals> - <configuration> - <sources> - <source>src/main/scala</source> - </sources> - </configuration> - </execution> - <execution> - <id>add-test-source</id> - <phase>generate-test-sources</phase> - <goals> - <goal>add-test-source</goal> - </goals> - <configuration> - <sources> - <source>src/test/scala</source> - </sources> - </configuration> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>sign-artifacts</id> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-gpg-plugin</artifactId> - <version>1.6</version> - <executions> - <execution> - <id>sign-artifacts</id> - <phase>verify</phase> - <goals> - <goal>sign</goal> - </goals> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - </profiles> - <dependencyManagement> - <dependencies> - <dependency> - <groupId>org.apache.maven.wagon</groupId> - <artifactId>wagon-ssh-external</artifactId> - <version>2.7</version> - </dependency> - <dependency> - <groupId>com.github.cverges.expect4j</groupId> - <artifactId>expect4j</artifactId> - <version>1.6</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-compress</artifactId> - <version>1.11</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-exec</artifactId> - <version>1.3</version> - </dependency> - <dependency> - <groupId>io.airlift</groupId> - <artifactId>airline</artifactId> - <version>0.7</version> - </dependency> - <dependency> - <groupId>aopalliance</groupId> - <artifactId>aopalliance</artifactId> - <version>1.0</version> - </dependency> - <dependency> - <groupId>org.ow2.asm</groupId> - <artifactId>asm</artifactId> - <version>5.2</version> - </dependency> - <dependency> - <groupId>com.google.code.findbugs</groupId> - <artifactId>annotations</artifactId> - <version>1.3.9</version> - </dependency> - <dependency> - <groupId>com.google.code.findbugs</groupId> - <artifactId>jsr305</artifactId> - <version>1.3.9</version> - </dependency> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - <version>18.0</version> - </dependency> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava-testlib</artifactId> - <version>18.0</version> - </dependency> - <dependency> - <groupId>com.google.inject</groupId> - <artifactId>guice</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.inject</groupId> - <artifactId>guice</artifactId> - <version>3.0</version> - <classifier>no_aop</classifier> - </dependency> - <dependency> - <groupId>com.google.inject.extensions</groupId> - <artifactId>guice-assistedinject</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.inject.extensions</groupId> - <artifactId>guice-multibindings</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - <version>3.4.0</version> - </dependency> - <dependency> - <groupId>com.googlecode.jmockit</groupId> - <artifactId>jmockit</artifactId> - <version>1.2</version> - </dependency> - <dependency> - <groupId>com.goldmansachs</groupId> - <artifactId>gs-collections</artifactId> - <version>6.1.0</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-core</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-databind</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-annotations</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-json-provider</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.module</groupId> - <artifactId>jackson-module-jaxb-annotations</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-base</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-xml-provider</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.dataformat</groupId> - <artifactId>jackson-dataformat-xml</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.datatype</groupId> - <artifactId>jackson-datatype-jdk8</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.datatype</groupId> - <artifactId>jackson-datatype-jsr310</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.infradna.tool</groupId> - <artifactId>bridge-method-annotation</artifactId> - <version>1.4</version> - </dependency> - <dependency> - <groupId>commons-cli</groupId> - <artifactId>commons-cli</artifactId> - <version>1.3.1</version> - </dependency> - <dependency> - <groupId>commons-codec</groupId> - <artifactId>commons-codec</artifactId> - <version>1.4</version> - </dependency> - <dependency> - <groupId>commons-collections</groupId> - <artifactId>commons-collections</artifactId> - <version>3.2.1</version> - </dependency> - <dependency> - <groupId>commons-configuration</groupId> - <artifactId>commons-configuration</artifactId> - <version>1.6</version> - </dependency> - <dependency> - <groupId>commons-daemon</groupId> - <artifactId>commons-daemon</artifactId> - <version>1.0.3</version> - </dependency> - <dependency> - <groupId>commons-io</groupId> - <artifactId>commons-io</artifactId> - <version>2.4</version> - </dependency> - <dependency> - <groupId>commons-lang</groupId> - <artifactId>commons-lang</artifactId> - <version>${commons-lang.version}</version> - </dependency> - <dependency> - <!-- This version is exported by jdisc via jcl-over-slf4j. --> - <groupId>commons-logging</groupId> - <artifactId>commons-logging</artifactId> - <version>1.1.1</version> - </dependency> - <dependency> - <groupId>commons-net</groupId> - <artifactId>commons-net</artifactId> - <version>2.0</version> - </dependency> - <dependency> - <groupId>commons-pool</groupId> - <artifactId>commons-pool</artifactId> - <version>1.5.6</version> - </dependency> - <!-- Explicitly included to get Zookeeper version 3.4.10, - can be excluded if you want the Zookeeper version - used by curator by default - --> - <dependency> - <groupId>org.apache.zookeeper</groupId> - <artifactId>zookeeper</artifactId> - <version>3.4.10</version> - </dependency> - <dependency> - <groupId>org.apache.curator</groupId> - <artifactId>curator-recipes</artifactId> - <version>${curator.version}</version> - </dependency> - <dependency> - <groupId>org.apache.curator</groupId> - <artifactId>curator-test</artifactId> - <version>${curator.version}</version> - </dependency> - <dependency> - <groupId>javax.servlet</groupId> - <artifactId>javax.servlet-api</artifactId> - <version>3.1.0</version> - </dependency> - <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> - <version>4.12</version> - </dependency> - <dependency> - <groupId>org.antlr</groupId> - <artifactId>antlr-runtime</artifactId> - <version>${antlr.version}</version> - </dependency> - <dependency> - <groupId>org.antlr</groupId> - <artifactId>antlr4-runtime</artifactId> - <version>${antlr4.version}</version> - </dependency> - <dependency> - <groupId>org.apache.aries.spifly</groupId> - <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId> - <version>${aries.spifly.version}</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-lang3</artifactId> - <version>3.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.framework</artifactId> - <version>4.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.log</artifactId> - <version>1.0.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.main</artifactId> - <version>4.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>fluent-hc</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpclient</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpcore</artifactId> - <version>4.3.3</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpmime</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-artifact</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-core</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-model</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven.plugin-tools</groupId> - <artifactId>maven-plugin-annotations</artifactId> - <version>3.5</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-plugin-api</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-project</artifactId> - <version>2.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <version>3.0.2</version> - </dependency> - <dependency> - <groupId>org.apache.maven.surefire</groupId> - <artifactId>surefire-junit4</artifactId> - <version>${surefire.version}</version> - </dependency> - <dependency> - <groupId>org.apache.maven.surefire</groupId> - <artifactId>surefire-providers</artifactId> - <version>${surefire.version}</version> - <type>pom</type> - </dependency> - <dependency> - <groupId>org.codehaus.jettison</groupId> - <artifactId>jettison</artifactId> - <version>1.3.1</version> - </dependency> - <dependency> - <groupId>org.cthul</groupId> - <artifactId>cthul-matchers</artifactId> - <version>1.0</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-continuation</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-server</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-servlet</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-servlets</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-util</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-http</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-jmx</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-all</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-core</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-library</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>uk.co.datumedge</groupId> - <artifactId>hamcrest-json</artifactId> - <version>0.2</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hdrhistogram</groupId> - <artifactId>HdrHistogram</artifactId> - <version>2.1.8</version> - </dependency> - <dependency> - <groupId>org.json</groupId> - <artifactId>json</artifactId> - <version>20090211</version> - </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-all</artifactId> - <version>1.9.5</version> - </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <version>1.9.5</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.osgi</groupId> - <artifactId>org.osgi.compendium</artifactId> - <version>4.3.0</version> - </dependency> - <dependency> - <groupId>org.osgi</groupId> - <artifactId>org.osgi.core</artifactId> - <version>4.3.0</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-library</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang.modules</groupId> - <artifactId>scala-parser-combinators_${scala.major-version}</artifactId> - <version>1.0.1</version> - </dependency> - <dependency> - <groupId>org.scala-lang.modules</groupId> - <artifactId>scala-xml_${scala.major-version}</artifactId> - <version>1.0.2</version> - </dependency> - <dependency> - <groupId>org.scalatest</groupId> - <artifactId>scalatest_${scala.major-version}</artifactId> - <version>2.2.2</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>jcl-over-slf4j</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>log4j-over-slf4j</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-api</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-jdk14</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.springframework</groupId> - <artifactId>spring-test</artifactId> - <version>4.0.6.RELEASE</version> - </dependency> - <dependency> - <groupId>org.testng</groupId> - <artifactId>testng</artifactId> - <version>6.10</version> - </dependency> - <dependency> - <groupId>org.twdata.maven</groupId> - <artifactId>mojo-executor</artifactId> - <version>2.3.0</version> - </dependency> - <dependency> - <groupId>net.jcip</groupId> - <artifactId>jcip-annotations</artifactId> - <version>1.0</version> - </dependency> - <dependency> - <groupId>net.jpountz.lz4</groupId> - <artifactId>lz4</artifactId> - <version>1.3.0</version> - </dependency> - <dependency> - <groupId>net.spy</groupId> - <artifactId>spymemcached</artifactId> - <version>2.10.1</version> - </dependency> - <dependency> - <groupId>xerces</groupId> - <artifactId>xercesImpl</artifactId> - <version>2.11.0</version> - </dependency> - <dependency> - <groupId>org.bouncycastle</groupId> - <artifactId>bcpkix-jdk15on</artifactId> - <version>${bouncycastle.version}</version> - </dependency> - <dependency> - <groupId>org.bouncycastle</groupId> - <artifactId>bcprov-jdk15on</artifactId> - <version>${bouncycastle.version}</version> - </dependency> - <!-- jersey 2 support --> - <dependency> - <groupId>javax.ws.rs</groupId> - <artifactId>javax.ws.rs-api</artifactId> - <version>${javax.ws.rs-api.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.containers</groupId> - <artifactId>jersey-container-servlet-core</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.containers</groupId> - <artifactId>jersey-container-servlet</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.media</groupId> - <artifactId>jersey-media-json-jackson</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.media</groupId> - <artifactId>jersey-media-multipart</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.ext</groupId> - <artifactId>jersey-proxy-client</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.core</groupId> - <artifactId>jersey-client</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>com.ibm.icu</groupId> - <artifactId>icu4j</artifactId> - <version>57.1</version> - </dependency> - <dependency> - <groupId>com.yahoo.athenz</groupId> - <artifactId>athenz-zms-java-client</artifactId> - <version>${athenz.version}</version> - </dependency> - <dependency> - <groupId>com.yahoo.athenz</groupId> - <artifactId>athenz-zts-java-client</artifactId> - <version>${athenz.version}</version> - </dependency> - </dependencies> - </dependencyManagement> - - <properties> - <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> <!-- must be kept in sync with version used by current jersey2.version --> - <antlr.version>3.5.2</antlr.version> - <antlr4.version>4.5</antlr4.version> - <aries.spifly.version>1.0.8</aries.spifly.version> - <aries.util.version>1.0.0</aries.util.version> - <asm-debug-all.version>5.0.3</asm-debug-all.version> - <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories --> - <athenz.version>1.7.28</athenz.version> - <bouncycastle.version>1.58</bouncycastle.version> - <commons-lang.version>2.6</commons-lang.version> - <!-- WARNING: If you change curator version, you also need to update - zkfacade/src/main/java/org/apache/curator/**/package-info.java - using something like - find zkfacade/src/main/java/org/apache/curator -name package-info.java | \ - xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g' - --> - <curator.version>2.9.1</curator.version> - <jackson2.version>2.8.3</jackson2.version> - <jersey2.version>2.23.2</jersey2.version> - <jetty.version>9.4.6.v20170531</jetty.version> - <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> - <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> - <test.hide>true</test.hide> - <doclint>all</doclint> - <scala.major-version>2.11</scala.major-version> - <scala.version>${scala.major-version}.4</scala.version> - <surefire.version>2.19.1</surefire.version> <!-- NOTE bjorncs 15.06.2017: Version 2.20 has OoM issues --> - </properties> - <modules> <module>application</module> <module>application-deploy-plugin</module> diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 09ccf9928b7..5f6717d9516 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -36,21 +36,6 @@ <version>${project.version}</version> </dependency> <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - <version>3.4.0</version> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - <version>1.4.0</version> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - <version>1.4.0</version> - </dependency> - <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> <scope>test</scope> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0eeb0a9e630..785ed78492e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.tensor.Tensor; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -19,30 +18,26 @@ public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> * - * @param name the name of the variable whose value to return. - * @return the value of the named variable. + * @param name The name of the variable whose value to return. + * @return The value of the named variable. */ public abstract Value get(String name); - /** Returns a variable as a tensor */ - @Override - public Tensor getTensor(String name) { return get(name).asTensor(); } - /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any * string. This may be used to implement more advanced variables whose * values are calculated at runtime from arguments. Supporting this in a - * context is optional. - * + * context is optional. + * * <p>This default implementation generates a name on the form * <code>name(argument1, argument2, ...argumentN).output</code>. * If there are no arguments the parenthesis are omitted. * If there is no output, the dot is omitted.</p> * - * @param name the name of this variable. - * @param arguments the parsed arguments as given in the textual expression. - * @param output the name of the value to output (to enable one named + * @param name The name of this variable. + * @param arguments The parsed arguments as given in the textual expression. + * @param output The name of the value to output (to enable one named * calculation to output several), or null to output the * "main" (or only) value. */ @@ -59,20 +54,20 @@ public abstract class Context implements EvaluationContext { * context subclasses. This default implementation throws * UnsupportedOperationException.</p> * - * @param index the index of the variable whose value to return. - * @return the value of the indexed variable. + * @param index The index of the variable whose value to return. + * @return The value of the indexed variable. */ public Value get(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); } /** - * Lookup by index rather than name directly to a double. This is supported by some optimized + * <p>Lookup by index rather than name directly to a double. This is supported by some optimized * context subclasses. This default implementation throws - * UnsupportedOperationException. + * UnsupportedOperationException.</p> * - * @param index the index of the variable whose value to return. - * @return the value of the indexed variable. + * @param index The index of the variable whose value to return. + * @return The value of the indexed variable. */ public double getDouble(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); @@ -86,23 +81,24 @@ public abstract class Context implements EvaluationContext { } /** - * Sets a value to this, or throws an UnsupportedOperationException if - * this is not supported. This default implementation does the latter. + * <p>Sets a value to this, or throws an UnsupportedOperationException if + * this is not supported. This default implementation does the latter.</p> * * - * @param name the name of the variable to set. + * @param name The name of the variable to set. * @param value the value to set. Ownership of this value is transferred to this - if it is mutable * (not frozen) it may be modified during execution + * @since 5.1.5 */ public void put(String name, Value value) { throw new UnsupportedOperationException(this + " does not support variable assignment"); } /** - * Returns all the names available in this, or throws an + * <p>Returns all the names available in this, or throws an * UnsupportedOperationException if this operation is not supported. This - * default implementation does the latter. + * default implementation does the latter.</p> * - * @return the set of all variable names. + * @return The set of all variable names. */ public Set<String> names() { throw new UnsupportedOperationException(this + " does not support return a list of its names"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2ef4a2ede2f..ea750295423 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -3,9 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -19,11 +16,6 @@ public abstract class DoubleCompatibleValue extends Value { public boolean hasDouble() { return true; } @Override - public Tensor asTensor() { - return doubleAsTensor(asDouble()); - } - - @Override public Value negate() { return new DoubleValue(-asDouble()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index dad69b31181..ac8aba6a617 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -4,14 +4,12 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A string value. * * @author bratseth + * @since 5.1.21 */ public class StringValue extends Value { @@ -37,11 +35,6 @@ public class StringValue extends Value { } @Override - public Tensor asTensor() { - return doubleAsTensor(asDouble()); - } - - @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 26c30fe5ed2..49c3ccb7b01 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -2,10 +2,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.searchlib.rankingexpression.rule.Function; +import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.TensorType; + +import java.util.Collections; +import java.util.Optional; /** * A Value containing a tensor. @@ -19,7 +23,7 @@ public class TensorValue extends Value { /** The tensor value of this */ private final Tensor value; - + public TensorValue(Tensor value) { this.value = value; } @@ -127,7 +131,7 @@ public class TensorValue extends Value { public Value compare(TruthOperator operator, Value argument) { return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); } - + private Tensor compareTensor(TruthOperator operator, Tensor argument) { switch (operator) { case LARGER: return value.larger(argument); @@ -148,7 +152,7 @@ public class TensorValue extends Value { else return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); } - + private Tensor functionOnTensor(Function function, Tensor argument) { switch (function) { case min: return value.min(argument); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 40d70e0022c..b2ccbe572d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -5,8 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * The result of a ranking expression evaluation. @@ -27,14 +25,6 @@ public abstract class Value { return new DoubleValue(asDouble()); } - /** Returns this as a tensor value */ - public abstract Tensor asTensor(); - - /** A utility method for wrapping a sdouble in a rank 0 tensor */ - protected Tensor doubleAsTensor(double value) { - return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build(); - } - /** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */ public abstract boolean hasDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java deleted file mode 100644 index b4a9b363ade..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ /dev/null @@ -1,51 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * The result of importing a TensorFlow model into Vespa: - * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model - * - A list of named constant tensors - * - A list of expected input tensors, with their tensor type - * - A list of warning messages - * - * @author bratseth - */ -// This object can be built incrementally within this package, but is immutable when observed from outside the package -// TODO: Retain signature structure in ImportResult (input + output-expression bundles) -public class ImportResult { - - private final List<RankingExpression> expressions = new ArrayList<>(); - private final Map<String, Tensor> constants = new HashMap<>(); - private final Map<String, TensorType> arguments = new HashMap<>(); - private final List<String> warnings = new ArrayList<>(); - - void add(RankingExpression expression) { expressions.add(expression); } - void set(String name, Tensor constant) { constants.put(name, constant); } - void set(String name, TensorType argument) { arguments.put(name, argument); } - void warn(String warning) { warnings.add(warning); } - - /** Returns an immutable list of the expressions of this */ - public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); } - - /** Returns an immutable map of the constants of this */ - public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } - - /** Returns an immutable map of the arguments of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } - - /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ - public List<String> warnings() { - return warnings.stream().sorted().collect(Collectors.toList()); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java deleted file mode 100644 index e7f7b5ef2f4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ /dev/null @@ -1,160 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; - -import java.util.ArrayList; -import java.util.List; -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; - -/** - * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. - * - * @author bratseth - */ -class OperationMapper { - - /* - A note on conversion from implicitly numbered to explicitly named dimensions: - Vespa tensor dimensions are explicitly named and thus have an explicit notion of being - 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation - comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation - around dimension renaming operations which mirrors those built into the TF operation definitions. - - To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' - dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation - and the result is then renamed again (if necessary) to recover this convention across a full nested - computation. - - This requires us to track tensor types throughout the conversion. - */ - - private TensorConverter tensorConverter = new TensorConverter(); - - TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { - ensureArguments(2, arguments, "join"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); - if (a.type().rank() < b.type().rank()) - throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " + - "but this is not supported when the second argument has a higher rank"); - - TensorFunction bFunction = b.function(); - - if (a.type().rank() > b.type().rank()) { - // Well now we have entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // I'm not able to extract from that any unambiguous specification of which dimensions - // should be "stretched" when the tensor do not have the same number of dimensions. - // From trying this with TensorFlow it appears that the second tensor is matched to the - // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. - // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). - List<String> renameFrom = new ArrayList<>(); - List<String> renameTo = new ArrayList<>(); - int sizeDifference = a.type().rank() - b.type().rank(); - for (int i = 0; i < b.type().rank(); i++) { - renameFrom.add(b.type().dimensions().get(i).name()); - renameTo.add("d" + (sizeDifference + i)); - } - bFunction = new Rename(bFunction, renameFrom, renameTo); - } - - Join function = new Join(a.function(), bFunction, doubleFunction); - return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank - } - - TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { - ensureArguments(1, arguments, "apply"); - TypedTensorFunction a = arguments.get(0); - - TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); - com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); - return new TypedTensorFunction(resultType, function); - } - - TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { - String name = tfNode.getName(); - TensorType type = result.arguments().get(name); - if (type == null) - throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name + - "', but there is no such input"); - // Included literally in the expression and so must be produced by a separate macro in the rank profile - return new TypedTensorFunction(type, new VariableTensor(name)); - } - - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { - if ( ! tfNode.getName().endsWith("/read")) - throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + - "nodes are only supported when reading variables"); - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but has " + - tfNode.getInputList().size()); - - String name = tfNode.getInput(0); - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); - Session.Runner fetched = model.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if ( importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + - importedTensors.size()); - Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); - result.set(name, constant); - return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); - } - - TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "matmul"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); - if (a.type().rank() < 2 || b.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (a.type().rank() != b.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - String afterLastDim = "d" + (a.type().rank() + 1); - // Let the first dimension of the second tensor be the same as the second dimension of the first - // and the second dimension of the second argument be not present in the first argument, while leaving the - // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. - - // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly - - Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), - ImmutableList.of("d1", afterLastDim)); - Matmul matmul = new Matmul(a.function(), renamedB, "d1"); - return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), - new Rename(matmul, afterLastDim, "d1")); - } - - TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "softmax"); - TypedTensorFunction a = arguments.get(0); - // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 - String dimension = "d" + (a.type().rank() - 1); - Softmax softmax = new Softmax(a.function(), dimension); - return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); - } - - private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { - if ( arguments.size() != count) - throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + - ", but got " + arguments.size()); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java deleted file mode 100644 index df43225c333..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ /dev/null @@ -1,94 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; - -/** - * @author bratseth - */ -public class TensorConverter { - - public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { - TensorType type = toVespaTensorType(tfTensor.shape()); - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - private TensorType toVespaTensorType(long[] shape) { - TensorType.Builder b = new TensorType.Builder(); - int dimensionIndex = 0; - for (long dimensionSize : shape) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed("d" + (dimensionIndex++), (int) dimensionSize); - } - return b.build(); - } - - private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { - switch (tfTensor.dataType()) { - case DOUBLE: return new DoubleValues(tfTensor); - case FLOAT: return new FloatValues(tfTensor); - // TODO: The rest - default: - throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); - } - } - - /** Allows reading values from buffers of various numeric types as bytes */ - private static abstract class Values { - - private final int size; - - protected Values(int size) { - this.size = size; - } - - abstract double get(int i); - - int size() { return size; } - - } - - private static class DoubleValues extends Values { - - private final DoubleBuffer values; - - DoubleValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = DoubleBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class FloatValues extends Values { - - private final FloatBuffer values; - - FloatValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = FloatBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java deleted file mode 100644 index 33523244129..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ /dev/null @@ -1,147 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; -import org.tensorflow.framework.TensorShapeProto; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Converts a saved TensorFlow model into a ranking expression and set of constants. - * - * @author bratseth - */ -public class TensorFlowImporter { - - private final OperationMapper operationMapper = new OperationMapper(); - - /** - * Imports a saved TensorFlow model from a directory. - * The model should be saved as a pbtxt file. - * The name of the model is taken as the db/pbtxt file name (not including the file ending). - * - * @param modelDir the directory containing the TensorFlow model files to import - */ - public ImportResult importModel(String modelDir) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); - } - catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); - } - } - - public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); - ImportResult result = new ImportResult(); - importInputs(signature.getInputsMap(), result); - result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); - return result; - } - catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); - } - } - - private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { - ImportResult result = new ImportResult(); - for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - importInputs(signatureEntry.getValue().getInputsMap(), result); - for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { - try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); - result.add(new RankingExpression(output.getKey(), node)); - } - catch (IllegalArgumentException e) { - result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); - } - } - } - return result; - } - - private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) { - inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), - importTensorType(value.getTensorShape()))); - } - - private TensorType importTensorType(TensorShapeProto tensorShape) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) { - int dimensionSize = (int)dimension.getSize(); - if (dimensionSize >= 0) - b.indexed("d" + b.rank(), dimensionSize); - else - b.indexed("d" + b.rank()); // unbound size - } - return b.build(); - } - - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { - return importNode(nameOf(output.getName()), graph, model, result); - } - - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); - return new TensorFunctionNode(function); // wrap top level (only) as an expression - } - - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tensorFunctionOf(tfNode, graph, model, result); - } - - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops - // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ - switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); - case "placeholder" : return operationMapper.placeholder(tfNode, result); - case "identity" : return operationMapper.identity(tfNode, model, result); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); - default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); - } - } - - private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) - .collect(Collectors.toList()); - } - - private NodeDef getNode(String name, GraphDef graph) { - return graph.getNodeList().stream() - .filter(node -> node.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'")); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - private String nameOf(String name) { - return name.split(":")[0]; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java deleted file mode 100644 index 5712da77700..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -/** - * A tensor function returning a specific tensor type - * - * @author bratseth - */ -final class TypedTensorFunction { - - private final TensorType type; - private final TensorFunction function; - - public TypedTensorFunction(TensorType type, TensorFunction function) { - this.type = type; - this.function = function; - } - - public TensorType type() { return type; } - public TensorFunction function() { return function; } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index d366c9bfbe5..71699b379b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -14,23 +14,23 @@ import java.util.function.*; /** * A tensor generating function, whose arguments are determined by a tensor type - * + * * @author bratseth */ public class GeneratorLambdaFunctionNode extends CompositeNode { private final TensorType type; private final ExpressionNode generator; - + public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) { if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent())) - throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + + throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + "dimensions, but tried to generate " + type); // TODO: Verify that the function only accesses the given arguments this.type = type; this.generator = generator; } - + @Override public List<ExpressionNode> children() { return Collections.singletonList(generator); @@ -53,8 +53,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { public Value evaluate(Context context) { return generator.evaluate(context); } - - /** + + /** * Returns this as an operator which converts a list of integers into a double */ public IntegerListToDoubleLambda asIntegerListToDoubleOperator() { @@ -70,7 +70,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { context.put(type.dimensions().get(i).name(), arguments.get(i)); return evaluate(context).asDouble(); } - + @Override public String toString() { return GeneratorLambdaFunctionNode.this.toString(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index ba765d07094..1f8db6e036c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -17,7 +17,7 @@ import java.util.Map; * @author bratseth */ public class SerializationContext { - + /** Expression functions indexed by name */ private final ImmutableMap<String, ExpressionFunction> functions; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 8af3448ca6f..ce21e132980 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -21,32 +21,22 @@ import java.util.stream.Collectors; * * @author bratseth */ -@Beta + @Beta public class TensorFunctionNode extends CompositeNode { private final TensorFunction function; - + public TensorFunctionNode(TensorFunction function) { this.function = function; } - /** Returns the tensor function wrapped by this */ - public TensorFunction function() { return function; } - @Override public List<ExpressionNode> children() { return function.functionArguments().stream() - .map(this::toExpressionNode) + .map(f -> ((TensorFunctionExpressionNode)f).expression) .collect(Collectors.toList()); } - private ExpressionNode toExpressionNode(TensorFunction f) { - if (f instanceof TensorFunctionExpressionNode) - return ((TensorFunctionExpressionNode)f).expression; - else - return new TensorFunctionNode(f); - } - @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction> wrappedChildren = children.stream() @@ -60,7 +50,7 @@ public class TensorFunctionNode extends CompositeNode { // Serialize as primitive return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); } - + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); @@ -69,8 +59,8 @@ public class TensorFunctionNode extends CompositeNode { public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } - - /** + + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ @@ -78,13 +68,13 @@ public class TensorFunctionNode extends CompositeNode { /** An expression which produces a tensor */ private final ExpressionNode expression; - + public TensorFunctionExpressionNode(ExpressionNode expression) { this.expression = expression; } - + @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> functionArguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -118,7 +108,7 @@ public class TensorFunctionNode extends CompositeNode { public String toString() { return toString(ExpressionNodeToStringContext.empty); } - + @Override public String toString(ToStringContext c) { if (c instanceof ExpressionNodeToStringContext) { @@ -131,14 +121,14 @@ public class TensorFunctionNode extends CompositeNode { } } - + /** Allows passing serialization context arguments through TensorFunctions */ private static class ExpressionNodeToStringContext implements ToStringContext { - + final SerializationContext context; final Deque<String> path; final CompositeNode parent; - + public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null); public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py deleted file mode 100644 index a1861a1c981..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""A very simple MNIST classifier. - -See extensive documentation at -https://www.tensorflow.org/get_started/mnist/beginners -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys - -from tensorflow.examples.tutorials.mnist import input_data - -import tensorflow as tf - -FLAGS = None - - -def main(_): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) - - # Create the model - x = tf.placeholder(tf.float32, [None, 784]) - W = tf.Variable(tf.zeros([784, 10])) - b = tf.Variable(tf.zeros([10])) - y = tf.matmul(x, W) + b - - # Define loss and optimizer - y_ = tf.placeholder(tf.float32, [None, 10]) - - # The raw formulation of cross-entropy, - # - # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), - # reduction_indices=[1])) - # - # can be numerically unstable. - # - # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw - # outputs of 'y', and then average across the batch. - cross_entropy = tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) - train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) - - sess = tf.InteractiveSession() - tf.global_variables_initializer().run() - # Train - for _ in range(1000): - batch_xs, batch_ys = mnist.train.next_batch(100) - sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) - - # Test trained model - correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - print(sess.run(accuracy, feed_dict={x: mnist.test.images, - y_: mnist.test.labels})) - - # Save the model - export_path = "saved" - print('Exporting trained model to ', export_path) - builder = tf.saved_model.builder.SavedModelBuilder(export_path) - signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) - builder.add_meta_graph_and_variables(sess, - [tf.saved_model.tag_constants.SERVING], - signature_def_map={'serving_default':signature}) - builder.save(as_text=True) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt deleted file mode 100644 index 8100dfd594d..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt +++ /dev/null @@ -1,5039 +0,0 @@ -saved_model_schema_version: 1 -meta_graphs { - meta_info_def { - stripped_op_list { - op { - name: "Add" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_STRING - } - } - } - } - op { - name: "ApplyGradientDescent" - input_arg { - name: "var" - type_attr: "T" - is_ref: true - } - input_arg { - name: "alpha" - type_attr: "T" - } - input_arg { - name: "delta" - type_attr: "T" - } - output_arg { - name: "out" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - } - } - op { - name: "ArgMax" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "dimension" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "output_type" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "output_type" - type: "type" - default_value { - type: DT_INT64 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Assign" - input_arg { - name: "ref" - type_attr: "T" - is_ref: true - } - input_arg { - name: "value" - type_attr: "T" - } - output_arg { - name: "output_ref" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - } - attr { - name: "validate_shape" - type: "bool" - default_value { - b: true - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: true - } - } - allows_uninitialized_input: true - } - op { - name: "BroadcastGradientArgs" - input_arg { - name: "s0" - type_attr: "T" - } - input_arg { - name: "s1" - type_attr: "T" - } - output_arg { - name: "r0" - type_attr: "T" - } - output_arg { - name: "r1" - type_attr: "T" - } - attr { - name: "T" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Cast" - input_arg { - name: "x" - type_attr: "SrcT" - } - output_arg { - name: "y" - type_attr: "DstT" - } - attr { - name: "SrcT" - type: "type" - } - attr { - name: "DstT" - type: "type" - } - } - op { - name: "ConcatV2" - input_arg { - name: "values" - type_attr: "T" - number_attr: "N" - } - input_arg { - name: "axis" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 2 - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Const" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "value" - type: "tensor" - } - attr { - name: "dtype" - type: "type" - } - } - op { - name: "Equal" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type: DT_BOOL - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_QUINT8 - type: DT_QINT8 - type: DT_QINT32 - type: DT_STRING - type: DT_BOOL - type: DT_COMPLEX128 - } - } - } - is_commutative: true - } - op { - name: "ExpandDims" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "dim" - type_attr: "Tdim" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tdim" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Fill" - input_arg { - name: "dims" - type: DT_INT32 - } - input_arg { - name: "value" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - op { - name: "FloorDiv" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Identity" - input_arg { - name: "input" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - op { - name: "MatMul" - input_arg { - name: "a" - type_attr: "T" - } - input_arg { - name: "b" - type_attr: "T" - } - output_arg { - name: "product" - type_attr: "T" - } - attr { - name: "transpose_a" - type: "bool" - default_value { - b: false - } - } - attr { - name: "transpose_b" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Maximum" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_INT64 - } - } - } - is_commutative: true - } - op { - name: "Mean" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "reduction_indices" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "keep_dims" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "MergeV2Checkpoints" - input_arg { - name: "checkpoint_prefixes" - type: DT_STRING - } - input_arg { - name: "destination_prefix" - type: DT_STRING - } - attr { - name: "delete_old_dirs" - type: "bool" - default_value { - b: true - } - } - is_stateful: true - } - op { - name: "Mul" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - is_commutative: true - } - op { - name: "NoOp" - } - op { - name: "Pack" - input_arg { - name: "values" - type_attr: "T" - number_attr: "N" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "T" - type: "type" - } - attr { - name: "axis" - type: "int" - default_value { - i: 0 - } - } - } - op { - name: "Placeholder" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "shape" - type: "shape" - default_value { - shape { - unknown_rank: true - } - } - } - } - op { - name: "Prod" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "reduction_indices" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "keep_dims" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "RealDiv" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Reshape" - input_arg { - name: "tensor" - type_attr: "T" - } - input_arg { - name: "shape" - type_attr: "Tshape" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "RestoreV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - output_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - is_stateful: true - } - op { - name: "SaveV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - input_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - is_stateful: true - } - op { - name: "Shape" - input_arg { - name: "input" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "out_type" - } - attr { - name: "T" - type: "type" - } - attr { - name: "out_type" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "ShardedFilename" - input_arg { - name: "basename" - type: DT_STRING - } - input_arg { - name: "shard" - type: DT_INT32 - } - input_arg { - name: "num_shards" - type: DT_INT32 - } - output_arg { - name: "filename" - type: DT_STRING - } - } - op { - name: "Slice" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "begin" - type_attr: "Index" - } - input_arg { - name: "size" - type_attr: "Index" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Index" - type: "type" - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "SoftmaxCrossEntropyWithLogits" - input_arg { - name: "features" - type_attr: "T" - } - input_arg { - name: "labels" - type_attr: "T" - } - output_arg { - name: "loss" - type_attr: "T" - } - output_arg { - name: "backprop" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - } - op { - name: "StringJoin" - input_arg { - name: "inputs" - type: DT_STRING - number_attr: "N" - } - output_arg { - name: "output" - type: DT_STRING - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "separator" - type: "string" - default_value { - s: "" - } - } - } - op { - name: "Sub" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - } - op { - name: "Sum" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "reduction_indices" - type_attr: "Tidx" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "keep_dims" - type: "bool" - default_value { - b: false - } - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "Tile" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "multiples" - type_attr: "Tmultiples" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tmultiples" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - } - op { - name: "VariableV2" - output_arg { - name: "ref" - type_attr: "dtype" - is_ref: true - } - attr { - name: "shape" - type: "shape" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - is_stateful: true - } - op { - name: "ZerosLike" - input_arg { - name: "x" - type_attr: "T" - } - output_arg { - name: "y" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - } - tags: "serve" - tensorflow_version: "1.4.1" - tensorflow_git_version: "v1.4.0-19-ga52c8d9b01" - } - graph_def { - node { - name: "Placeholder" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - node { - name: "zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - float_val: 0.0 - } - } - } - } - node { - name: "Variable" - op: "VariableV2" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "Variable/Assign" - op: "Assign" - input: "Variable" - input: "zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "Variable/read" - op: "Identity" - input: "Variable" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "zeros_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 10 - } - } - float_val: 0.0 - } - } - } - } - node { - name: "Variable_1" - op: "VariableV2" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "Variable_1/Assign" - op: "Assign" - input: "Variable_1" - input: "zeros_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "Variable_1/read" - op: "Identity" - input: "Variable_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "MatMul" - op: "MatMul" - input: "Placeholder" - input: "Variable/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } - } - node { - name: "add" - op: "Add" - input: "MatMul" - input: "Variable_1/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "Placeholder_1" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - node { - name: "Rank" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node { - name: "Shape" - op: "Shape" - input: "add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "Rank_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node { - name: "Shape_1" - op: "Shape" - input: "add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "Sub/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "Slice/begin" - op: "Pack" - input: "Sub" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "Slice/size" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - } - node { - name: "concat/values_0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - } - node { - name: "concat/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - } - node { - name: "Reshape" - op: "Reshape" - input: "add" - input: "concat" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Rank_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node { - name: "Shape_2" - op: "Shape" - input: "Placeholder_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "Sub_1/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "Slice_1/size" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node { - name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - } - node { - name: "concat_1/values_0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - } - node { - name: "concat_1/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - } - node { - name: "Reshape_1" - op: "Reshape" - input: "Placeholder_1" - input: "concat_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - } - node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape" - input: "Reshape_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Sub_2/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "Slice_2/begin" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Reshape_2" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "Mean" - op: "Mean" - input: "Reshape_2" - input: "Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/Shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "gradients/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } - } - node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - } - node { - name: "gradients/Mean_grad/Shape" - op: "Shape" - input: "Reshape_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/Mean_grad/Shape_1" - op: "Shape" - input: "Reshape_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/Mean_grad/Shape_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "gradients/Mean_grad/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_2" - input: "gradients/Mean_grad/Const_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/Mean_grad/Maximum/y" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: "gradients/Mean_grad/Maximum" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/Mean_grad/Shape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: "gradients/Mean_grad/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/Reshape_2_grad/Shape" - op: "Shape" - input: "SoftmaxCrossEntropyWithLogits" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/Reshape_2_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_2_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/zeros_like" - op: "ZerosLike" - input: "SoftmaxCrossEntropyWithLogits:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } - } - node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_2_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "SoftmaxCrossEntropyWithLogits:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/Reshape_grad/Shape" - op: "Shape" - input: "add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/Reshape_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Shape" - op: "Shape" - input: "MatMul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node { - name: "gradients/add_grad/Shape_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } - } - node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/Reshape_grad/Reshape" - input: "gradients/add_grad/BroadcastGradientArgs" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Reshape_grad/Reshape" - input: "gradients/add_grad/BroadcastGradientArgs:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - } - node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - } - node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "Variable/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } - } - node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Placeholder" - input: "gradients/add_grad/tuple/control_dependency" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } - } - node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - } - node { - name: "gradients/MatMul_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - } - } - node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - } - node { - name: "GradientDescent/learning_rate" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } - } - node { - name: "GradientDescent/update_Variable/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } - } - node { - name: "GradientDescent/update_Variable_1/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "Variable_1" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } - } - node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_Variable/ApplyGradientDescent" - input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" - } - node { - name: "init" - op: "NoOp" - input: "^Variable/Assign" - input: "^Variable_1/Assign" - } - node { - name: "ArgMax/dimension" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "ArgMax" - op: "ArgMax" - input: "add" - input: "ArgMax/dimension" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_type" - value { - type: DT_INT64 - } - } - } - node { - name: "ArgMax_1/dimension" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "ArgMax_1" - op: "ArgMax" - input: "Placeholder_1" - input: "ArgMax_1/dimension" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_type" - value { - type: DT_INT64 - } - } - } - node { - name: "Equal" - op: "Equal" - input: "ArgMax" - input: "ArgMax_1" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Cast_1" - op: "Cast" - input: "Equal" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - } - node { - name: "Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node { - name: "Mean_1" - op: "Mean" - input: "Cast_1" - input: "Const_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node { - name: "save/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "model" - } - } - } - } - node { - name: "save/StringJoin/inputs_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part" - } - } - } - } - node { - name: "save/StringJoin" - op: "StringJoin" - input: "save/Const" - input: "save/StringJoin/inputs_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "separator" - value { - s: "" - } - } - } - node { - name: "save/num_shards" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "save/ShardedFilename/shard" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "save/ShardedFilename" - op: "ShardedFilename" - input: "save/StringJoin" - input: "save/ShardedFilename/shard" - input: "save/num_shards" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/SaveV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "Variable" - string_val: "Variable_1" - } - } - } - } - node { - name: "save/SaveV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "" - string_val: "" - } - } - } - } - node { - name: "save/SaveV2" - op: "SaveV2" - input: "save/ShardedFilename" - input: "save/SaveV2/tensor_names" - input: "save/SaveV2/shape_and_slices" - input: "Variable" - input: "Variable_1" - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - } - } - } - } - node { - name: "save/control_dependency" - op: "Identity" - input: "save/ShardedFilename" - input: "^save/SaveV2" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@save/ShardedFilename" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/MergeV2Checkpoints/checkpoint_prefixes" - op: "Pack" - input: "save/ShardedFilename" - input: "^save/control_dependency" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "save/MergeV2Checkpoints" - op: "MergeV2Checkpoints" - input: "save/MergeV2Checkpoints/checkpoint_prefixes" - input: "save/Const" - attr { - key: "delete_old_dirs" - value { - b: true - } - } - } - node { - name: "save/Identity" - op: "Identity" - input: "save/Const" - input: "^save/control_dependency" - input: "^save/MergeV2Checkpoints" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/RestoreV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "Variable" - } - } - } - } - node { - name: "save/RestoreV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2/tensor_names" - input: "save/RestoreV2/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign" - op: "Assign" - input: "Variable" - input: "save/RestoreV2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/RestoreV2_1/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "Variable_1" - } - } - } - } - node { - name: "save/RestoreV2_1/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2_1" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2_1/tensor_names" - input: "save/RestoreV2_1/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign_1" - op: "Assign" - input: "Variable_1" - input: "save/RestoreV2_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Variable_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/restore_shard" - op: "NoOp" - input: "^save/Assign" - input: "^save/Assign_1" - } - node { - name: "save/restore_all" - op: "NoOp" - input: "^save/restore_shard" - } - versions { - producer: 24 - } - } - saver_def { - filename_tensor_name: "save/Const:0" - save_tensor_name: "save/Identity:0" - restore_op_name: "save/restore_all" - max_to_keep: 5 - sharded: true - keep_checkpoint_every_n_hours: 10000.0 - version: V2 - } - collection_def { - key: "train_op" - value { - node_list { - value: "GradientDescent" - } - } - } - collection_def { - key: "trainable_variables" - value { - bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" - } - } - } - collection_def { - key: "variables" - value { - bytes_list { - value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" - value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" - } - } - } - signature_def { - key: "serving_default" - value { - inputs { - key: "x" - value { - name: "Placeholder:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - } - } - outputs { - key: "y" - value { - name: "add:0" - dtype: DT_FLOAT - tensor_shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - method_name: "tensorflow/serving/predict" - } - } -} diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differdeleted file mode 100644 index 8474aa0a04c..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 +++ /dev/null diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index Binary files differdeleted file mode 100644 index cfcdac20409..00000000000 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index +++ /dev/null diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 3aa2d144f1f..82e5d0cfe5b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -4,14 +4,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.rule.*; +import com.yahoo.tensor.Tensor; import org.junit.Test; - import static org.junit.Assert.assertEquals; /** @@ -88,7 +83,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0, "sin(0)"); tester.assertEvaluates(1, "cos(0)"); tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))"); - + // Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero) tester.assertEvaluates(0, "random(1)"); tester.assertEvaluates(0, "random(foo)"); @@ -157,7 +152,7 @@ public class EvaluationTestCase { "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); - + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index ba0db4de5e1..ee2b1c147e3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -34,7 +34,7 @@ public class EvaluationTester { } // TODO: Test both bound and unbound indexed - public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, + public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, String ... tensorArgumentStrings) { MapContext context = defaultContext.thawedCopy(); int argumentIndex = 0; @@ -46,7 +46,7 @@ public class EvaluationTester { argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString); context.put("tensor" + (argumentIndex++), new TensorValue(argument)); } - return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, + return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, mappedTensors ? "Mapped tensors" : "Indexed tensors"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java deleted file mode 100644 index 863479f3531..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ /dev/null @@ -1,114 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * @author bratseth - */ -public class Mnist_SoftmaxTestCase { - - @Test - public void testImporting() { - String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - ImportResult result = new TensorFlowImporter().importModel(modelDir); - - // Check logged messages - result.warnings().forEach(System.err::println); - assertEquals(0, result.warnings().size()); - - // Check arguments - assertEquals(1, result.arguments().size()); - TensorType argument0 = result.arguments().get("Placeholder"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // Check constants - assertEquals(2, result.constants().size()); - - Tensor constant0 = result.constants().get("Variable"); - assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), - constant0.type()); - assertEquals(7840, constant0.size()); - - Tensor constant1 = result.constants().get("Variable_1"); - assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), - constant1.type()); - assertEquals(10, constant1.size()); - - // Check resulting Vespa expression - assertEquals(1, result.expressions().size()); - assertEquals("y", result.expressions().get(0).getName()); - assertEquals("" + - "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + - "rename(constant(Variable_1), d0, d1), " + - "f(a,b)(a + b))", - toNonPrimitiveString(result.expressions().get(0))); - - // Test execution - String signatureName = "serving_default"; - - assertEqualResult(modelDir, signatureName, "Variable/read"); - assertEqualResult(modelDir, signatureName, "Variable_1/read"); - // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder"); - assertEqualResult(modelDir, signatureName, "MatMul"); - assertEqualResult(modelDir, signatureName, "add"); - } - - private void assertEqualResult(String modelDir, String signatureName, String operationName) { - ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName); - - Tensor tfResult = tensorFlowExecute(modelDir, operationName); - Context context = contextFrom(result); - Tensor placeholder = placeholderArgument(); - context.put("Placeholder", new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); - } - - private Tensor tensorFlowExecute(String modelDir, String operationName) { - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); - runner.feed("Placeholder", placeholder); - List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); - } - - private Context contextFrom(ImportResult result) { - MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - return context; - } - - private String toNonPrimitiveString(RankingExpression expression) { - // toString on the wrapping expression will map to primitives, which is harder to read - return ((TensorFunctionNode)expression.getRoot()).function().toString(); - } - - private Tensor placeholderArgument() { - int size = 784; - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); - for (int i = 0; i < size; i++) - b.cell(0, 0, i); - return b.build(); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index 1960c1fe876..dde9d4bf21e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -59,7 +59,7 @@ public class TensorConformanceTest { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); - + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); return true; @@ -67,7 +67,7 @@ public class TensorConformanceTest { if (!node.has("expression")) { return true; // ignore } - + String expression = node.get("expression").asText(); MapContext context = getInput(node.get("inputs")); Tensor expect = getTensor(node.get("result").get("expect").asText()); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java index 16a541f939c..7874dcb24ab 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java @@ -5,7 +5,6 @@ import com.google.common.annotations.Beta; import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.config.Endpoint; import com.yahoo.vespa.http.client.core.Document; -import com.yahoo.vespa.http.client.core.Exceptions; import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory; import com.yahoo.vespa.http.client.core.EndpointResult; import com.yahoo.vespa.http.client.core.ServerResponseException; @@ -319,28 +318,29 @@ class IOThread implements Runnable, AutoCloseable { successfullHandshakes.getAndIncrement(); } catch (ServerResponseException ser) { executeProblemsCounter.incrementAndGet(); - log.log(Level.INFO, "Handshake did not work out " + endpoint, Exceptions.toMessageString(ser)); + log.log(Level.INFO, "Handshake did not work out " + endpoint, ser.getMessage()); drainFirstDocumentsInQueueIfOld(); return ThreadState.CONNECTED; } catch (Throwable throwable) { // This cover IOException as well executeProblemsCounter.incrementAndGet(); - log.log(Level.INFO, "Problem with Handshake " + endpoint, Exceptions.toMessageString(throwable)); + log.log(Level.INFO, "Problem with Handshake " + endpoint, throwable.getMessage()); drainFirstDocumentsInQueueIfOld(); client.close(); return ThreadState.DISCONNECTED; } return ThreadState.SESSION_SYNCED; case SESSION_SYNCED: + final int maxWaitTimeMilliSecs = 100; try { - ProcessResponse processResponse = pullAndProcessData(100); + ProcessResponse processResponse = pullAndProcessData(maxWaitTimeMilliSecs); gatewayThrottler.handleCall(processResponse.transitiveErrorCount); } catch (ServerResponseException ser) { - log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(ser)); + log.info("Problems while handing data over to gateway " + endpoint + " " + ser.getMessage()); return ThreadState.CONNECTED; } catch (Throwable e) { // Covers IOException as well - log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(e)); + log.info("Problems while handing data over to gateway " + endpoint + " " + e.getMessage()); client.close(); return ThreadState.DISCONNECTED; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index f6237a1977a..00e106dd035 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -7,7 +7,7 @@ import java.util.Arrays; /** * The sizes of a set of dimensions. - * + * * @author bratseth */ @Beta @@ -48,7 +48,7 @@ public final class DimensionSizes { @Override public int hashCode() { return Arrays.hashCode(sizes); } - /** + /** * Builder of a set of dimension sizes. * Dimensions whose size is not set before building will get size 0. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6b0d769de9f..c207dabca3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; - + /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - + private final double[] values; - + private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { this.type = type; this.dimensionSizes = dimensionSizes; @@ -43,8 +43,8 @@ public class IndexedTensor implements Tensor { } /** - * Returns an iterator over the cells of this. - * Cells are returned in order of increasing indexes in each dimension, increasing + * Returns an iterator over the cells of this. + * Cells are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor { /** * Returns an iterator over the values of this. - * Values are returned in order of increasing indexes in each dimension, increasing + * Values are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor { * Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions * given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as * other iterator. - * + * * @param dimensions the names of the dimensions of the superspace * @param sizes the size of each dimension in the space we are returning values for, containing * one value per dimension of this tensor (in order). Each size may be the same or smaller @@ -96,9 +96,9 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } - /** + /** * Returns the value at the given indexes - * + * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ @@ -119,7 +119,7 @@ public class IndexedTensor implements Tensor { } private double get(int valueIndex) { return values[valueIndex]; } - + private static int toValueIndex(int[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed @@ -165,7 +165,7 @@ public class IndexedTensor implements Tensor { public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) return Collections.singletonMap(TensorAddress.of(), values[0]); - + ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { @@ -174,13 +174,13 @@ public class IndexedTensor implements Tensor { } return builder.build(); } - + @Override public int hashCode() { return Arrays.hashCode(values); } @Override public String toString() { return Tensor.toStandardString(this); } - + @Override public boolean equals(Object other) { if ( ! ( other instanceof Tensor)) return false; @@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor { } public abstract static class Builder implements Tensor.Builder { - + final TensorType type; - + private Builder(TensorType type) { this.type = type; } @@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor { return new UnboundBuilder(type); } - /** + /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, * and, agree with the type size information when specified in the type. * If sizes are completely specified in the type this size information is redundant. @@ -210,16 +210,16 @@ public class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + + throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + "for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Integer> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + sizes.size(i) + " but cannot be larger than " + size.get() + " in " + type); } - + return new BoundBuilder(type, sizes); } @@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor { public abstract IndexedTensor build(); } - + /** A bound builder can create the double array directly */ public static class BoundBuilder extends Builder { @@ -257,13 +257,13 @@ public class IndexedTensor implements Tensor { this.sizes = sizes; values = new double[sizes.totalSize()]; } - + @Override public BoundBuilder cell(double value, int ... indexes) { values[toValueIndex(indexes, sizes)] = value; return this; } - + @Override public CellBuilder cell() { return new CellBuilder(type, this); @@ -294,8 +294,8 @@ public class IndexedTensor implements Tensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. + /** + * Set a cell value by the index in the internal layout of this cell. * This requires knowledge of the internal layout of cells in this implementation, and should therefore * probably not be used (but when it can be used it is fast). */ @@ -330,7 +330,7 @@ public class IndexedTensor implements Tensor { fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } - + private DimensionSizes findDimensionSizes(List<Object> firstDimension) { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); @@ -347,16 +347,16 @@ public class IndexedTensor implements Tensor { if (currentDimensionIndex == dimensionSizes.size()) dimensionSizes.add(currentDimension.size()); else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size()) - throw new IllegalArgumentException("Missing values in dimension " + + throw new IllegalArgumentException("Missing values in dimension " + type.dimensions().get(currentDimensionIndex) + " in " + type); - + for (Object value : currentDimension) if (value instanceof List) findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List<Object>)value); } @SuppressWarnings("unchecked") - private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension, + private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension, DimensionSizes sizes, double[] values) { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (int i = 0; i < currentDimension.size(); i++) @@ -369,7 +369,7 @@ public class IndexedTensor implements Tensor { } } } - + private double nullAsZero(Double value) { if (value == null) return 0; return value; @@ -431,7 +431,7 @@ public class IndexedTensor implements Tensor { } } - + private final class CellIterator implements Iterator<Cell> { private int count = 0; @@ -451,7 +451,7 @@ public class IndexedTensor implements Tensor { reusedCell.value = get(indexes.toSourceValueIndex()); return reusedCell; } - + } private final class ValueIterator implements Iterator<Double> { @@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor { } } - + private final class SuperspaceIterator implements Iterator<SubspaceIterator> { private final Indexes superindexes; /** Those indexes this should iterate over */ private final List<Integer> subdimensionIndexes; - - /** + + /** * The sizes of the space we'll return values of, one value for each dimension of this tensor, - * which may be equal to or smaller than the sizes of this tensor + * which may be equal to or smaller than the sizes of this tensor */ private final DimensionSizes iterateSizes; private int count = 0; - + private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) { this.iterateSizes = iterateSizes; - + List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first @@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor { else subdimensionIndexes.add(i); } - + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes); } - + @Override public boolean hasNext() { return count < superindexes.size(); @@ -527,7 +527,7 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator<Tensor.Cell> { - /** + /** * This iterator will iterate over the given dimensions, in the order given * (the first dimension index given is incremented to exhaustion first (i.e is etc.). * This may be any subset of the dimensions given by address and dimensionSizes. @@ -538,21 +538,21 @@ public class IndexedTensor implements Tensor { private Indexes indexes; private int count = 0; - + /** A lazy cell for reuse */ private final LazyCell reusedCell; - - /** + + /** * Creates a new subspace iterator - * + * * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the * type of the tensor this iterates over. This iterator will iterate over these - * dimensions to exhaustion in the order given (the first dimension index given is + * dimensions to exhaustion in the order given (the first dimension index given is * incremented to exhaustion first (i.e is etc.), while other dimensions will be held * at a constant position. * This may be any subset of the dimensions given by address and dimensionSizes. * This is treated as immutable. - * @param address the address of the first cell of this subspace. + * @param address the address of the first cell of this subspace. */ private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; @@ -561,26 +561,26 @@ public class IndexedTensor implements Tensor { this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); reusedCell = new LazyCell(indexes, Double.NaN); } - + /** Returns the total number of cells in this subspace */ - public int size() { + public int size() { return indexes.size(); } - + /** Returns the address of the cell this currently points to (which may be an invalid position) */ public TensorAddress address() { return indexes.toAddress(); } - + /** Rewind this iterator to the first element */ - public void reset() { + public void reset() { this.count = 0; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); } - + @Override public boolean hasNext() { - return count < indexes.size(); + return count < indexes.size(); } - + /** Returns the next cell, which is valid until next() is called again */ @Override public Cell next() { @@ -611,15 +611,15 @@ public class IndexedTensor implements Tensor { public TensorAddress getKey() { return indexes.toAddress(); } - + @Override public Double getValue() { return value; } } // TODO: Make dimensionSizes a class - - /** + + /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once * before accessing the first position. @@ -631,7 +631,7 @@ public class IndexedTensor implements Tensor { private final DimensionSizes iterationSizes; protected final int[] indexes; - + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -676,14 +676,14 @@ public class IndexedTensor implements Tensor { return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size); } } - + private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) iterationDimensions.add(length - 1 - i); return iterationDimensions; } - + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { this.sourceSizes = sourceSizes; this.iterationSizes = iterationSizes; @@ -708,9 +708,9 @@ public class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public int[] indexesForReading() { return indexes; } - - int toSourceValueIndex() { - return IndexedTensor.toValueIndex(indexes, sourceSizes); + + int toSourceValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceSizes); } int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } @@ -729,9 +729,9 @@ public class IndexedTensor implements Tensor { public String toString() { return "indexes " + Arrays.toString(indexes); } - + public abstract int size(); - + public abstract void next(); } @@ -763,18 +763,18 @@ public class IndexedTensor implements Tensor { public void next() {} } - + private static class MultiDimensionIndexes extends Indexes { private final int size; private final List<Integer> iterateDimensions; - + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; - + // Initialize to the (virtual) position before the first cell indexes[iterateDimensions.get(0)]--; } @@ -785,10 +785,10 @@ public class IndexedTensor implements Tensor { return size; } - /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. - * + /** + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. + * * @throws RuntimeException if this is called more times than its size */ @Override @@ -802,12 +802,12 @@ public class IndexedTensor implements Tensor { } } - + /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { private int lastComputedSourceValueIndex = -1; - + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); } @@ -827,7 +827,7 @@ public class IndexedTensor implements Tensor { private final int size; private final int iterateDimension; - + /** Maintain this directly as an optimization for 1-d iteration */ private int currentSourceValueIndex, currentIterationValueIndex; @@ -847,7 +847,7 @@ public class IndexedTensor implements Tensor { currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes); currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override public int size() { @@ -855,8 +855,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ @@ -888,7 +888,7 @@ public class IndexedTensor implements Tensor { /** The iteration step in the value index space */ private final int step; - private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, + private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, int iterateDimension, int[] initialIndexes, int size) { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; @@ -907,8 +907,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index aba61478e69..618bff0caae 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -27,7 +27,7 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } - + @Override public int size() { return cells.size(); } @@ -56,16 +56,16 @@ public class MappedTensor implements Tensor { } public static class Builder implements Tensor.Builder { - + private final TensorType type; private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - + public static Builder of(TensorType type) { return new Builder(type); } private Builder(TensorType type) { this.type = type; } - + public CellBuilder cell() { return new CellBuilder(type, this); } @@ -89,24 +89,24 @@ public class MappedTensor implements Tensor { public MappedTensor build() { return new MappedTensor(type, cells.build()); } - + } private static class CellIteratorAdaptor implements Iterator<Cell> { private final Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator; - + private CellIteratorAdaptor(Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator) { this.adaptedIterator = adaptedIterator; } - + @Override public boolean hasNext() { return adaptedIterator.hasNext(); } @Override public Cell next() { Map.Entry<TensorAddress, Double> entry = adaptedIterator.next(); - return new Cell(entry.getKey(), entry.getValue()); + return new Cell(entry.getKey(), entry.getValue()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 9a751e078e0..79bb27fcd1b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -117,7 +117,7 @@ public class MixedTensor implements Tensor { return index.denseSubspaceSize(); } - + /** * Base class for building mixed tensors. */ @@ -286,7 +286,7 @@ public class MixedTensor implements Tensor { } return typeBuilder.build(); } - + } /** @@ -360,7 +360,7 @@ public class MixedTensor implements Tensor { } return denseSubspaceSize; } - + private TensorAddress sparsePartialAddress(TensorAddress address) { if (type.dimensions().size() != address.size()) { throw new IllegalArgumentException("Tensor type and address are not of same size."); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 1b60e01cf7e..2ed211539d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -52,7 +52,7 @@ import java.util.function.Function; public interface Tensor { // ----------------- Accessors - + TensorType type(); /** Returns whether this have any cells */ @@ -70,13 +70,13 @@ public interface Tensor { /** Returns the values of this in some undefined order */ Iterator<Double> valueIterator(); - /** + /** * Returns an immutable map of the cells of this in no particular order. - * This may be expensive for some implementations - avoid when possible + * This may be expensive for some implementations - avoid when possible */ Map<TensorAddress, Double> cells(); - /** + /** * Returns the value of this as a double if it has no dimensions and one value * * @throws IllegalStateException if this does not have zero dimensions and one value @@ -87,9 +87,9 @@ public interface Tensor { if (size() == 0) return Double.NaN; return valueIterator().next(); } - + // ----------------- Primitive tensor functions - + default Tensor map(DoubleUnaryOperator mapper) { return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } @@ -108,7 +108,7 @@ public interface Tensor { } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), + return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); } @@ -123,13 +123,13 @@ public interface Tensor { default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - + static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) { return new Generate(type, valueSupplier).evaluate(); } - + // ----------------- Composite tensor functions which have a defined primitive mapping - + default Tensor l1Normalize(String dimension) { return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } @@ -231,7 +231,7 @@ public interface Tensor { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - + Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); StringBuilder b = new StringBuilder("{"); @@ -253,7 +253,7 @@ public interface Tensor { */ boolean equals(Object o); - /** + /** * Implement here to make this work across implementations. * Implementations must override equals and call this because this is an interface and cannot override equals. */ @@ -328,13 +328,13 @@ public interface Tensor { @Override public TensorAddress getKey() { return address; } - /** + /** * Returns the direct index which can be used to locate this cell, or -1 if not available. * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ int getDirectIndex() { return -1; } - + @Override public Double getValue() { return value; } @@ -388,20 +388,20 @@ public interface Tensor { /** Returns the type this is building */ TensorType type(); - + /** Return a cell builder */ CellBuilder cell(); /** Add a cell */ Builder cell(TensorAddress address, double value); - + /** Add a cell */ Builder cell(double value, int ... labels); - /** - * Add a cell - * - * @param cell a cell providing the location at which to add this cell + /** + * Add a cell + * + * @param cell a cell providing the location at which to add this cell * @param value the value to assign to the cell */ default Builder cell(Cell cell, double value) { @@ -409,12 +409,12 @@ public interface Tensor { } Tensor build(); - + class CellBuilder { private final TensorAddress.Builder addressBuilder; private final Tensor.Builder tensorBuilder; - + CellBuilder(TensorType type, Tensor.Builder tensorBuilder) { addressBuilder = new TensorAddress.Builder(type); this.tensorBuilder = tensorBuilder; @@ -436,5 +436,5 @@ public interface Tensor { } } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index ff1202463f2..7161450d5d5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -32,10 +32,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { /** Returns the number of labels in this */ public abstract int size(); - + /** - * Returns the i'th label in this - * + * Returns the i'th label in this + * * @throws IllegalArgumentException if there is no label at this index */ public abstract String label(int i); @@ -102,23 +102,23 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { private StringTensorAddress(String ... labels) { this.labels = Arrays.copyOf(labels, labels.length); } - + @Override public int size() { return labels.length; } - + @Override public String label(int i) { return labels[i]; } - + @Override - public int intLabel(int i) { + public int intLabel(int i) { try { return Integer.parseInt(labels[i]); - } + } catch (NumberFormatException e) { throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); } } - + @Override public TensorAddress withLabel(int index, int label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); @@ -169,7 +169,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { private final TensorType type; private final String[] labels; - + public Builder(TensorType type) { this(type, new String[type.dimensions().size()]); } @@ -193,7 +193,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { labels[labelIndex.get()] = label; return this; } - + /** Creates a copy of this which can be modified separately */ public Builder copy() { return new Builder(type, Arrays.copyOf(labels, labels.length)); @@ -202,7 +202,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public TensorAddress build() { for (int i = 0; i < labels.length; i++) if (labels[i] == null) - throw new IllegalArgumentException("Missing a value for dimension " + + throw new IllegalArgumentException("Missing a value for dimension " + type.dimensions().get(i).name() + " for " + type); return TensorAddress.of(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 9b3a9328f07..da8ab3bb0ec 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -96,7 +96,7 @@ class TensorParser { if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } - + TensorAddress address = addressBuilder.build(); Double value = asDouble(address, s.substring(0, valueEnd).trim()); builder.cell(address, value); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 914d853aeca..c05c35d6df3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,17 +53,14 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } - /** Returns the number of dimensions of this: dimensions().size() */ - public int rank() { return dimensions.size(); } - /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } - + /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } - + /** Returns the dimension with this name, or empty if not present */ public Optional<Dimension> dimension(String name) { return indexOfDimension(name).map(i -> dimensions.get(i)); @@ -77,7 +74,7 @@ public class TensorType { return Optional.empty(); } - /** + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. */ @@ -131,9 +128,9 @@ public class TensorType { private final String name; - private Dimension(String name) { + private Dimension(String name) { Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = name; } public final String name() { return name; } @@ -149,7 +146,7 @@ public class TensorType { /** Returns true if this is an indexed bound or unboun type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } - /** + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types. This works by degrading to the type making the fewer promises. * [N] + [M] = [min(N, M)] @@ -168,7 +165,7 @@ public class TensorType { IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } - + @Override public abstract String toString(); @@ -178,21 +175,21 @@ public class TensorType { if (other == null || getClass() != other.getClass()) return false; return name.equals(((Dimension)other).name); } - + @Override public int hashCode() { return name.hashCode(); } - + @Override public int compareTo(Dimension other) { return this.name.compareTo(other.name); } - + public static Dimension indexed(String name, int size) { return new IndexedBoundDimension(name, size); } - + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -292,9 +289,9 @@ public class TensorType { public Builder() { } - /** - * Creates a builder containing a combination of the dimensions of the given types - * + /** + * Creates a builder containing a combination of the dimensions of the given types + * * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. @@ -328,12 +325,9 @@ public class TensorType { } } - /** Returns the current number of dimensions in this */ - public int rank() { return dimensions.size(); } - - /** + /** * Adds a new dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ private Builder add(Dimension dimension) { @@ -352,7 +346,7 @@ public class TensorType { return this; } - /** + /** * Adds a bound indexed dimension to this * * @throws IllegalArgumentException if the dimension is already present @@ -361,7 +355,7 @@ public class TensorType { /** * Adds an unbound indexed dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ public Builder indexed(String name) { @@ -381,7 +375,7 @@ public class TensorType { public Builder dimension(Dimension dimension) { return add(dimension); } - + /** Returns the given dimension, or empty if none is present */ public Optional<Dimension> getDimension(String dimension) { return Optional.ofNullable(dimensions.get(dimension)); @@ -399,7 +393,7 @@ public class TensorType { public TensorType build() { return new TensorType(dimensions.values()); } - + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3db661f8a23..84caca78fb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -2,17 +2,16 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; + +import java.util.HashMap; /** * An evaluation context which is passed down to all nested functions during evaluation. - * + * The default context is empty to allow various evaluation frameworks to support their own implementation. + * * @author bratseth */ @Beta public interface EvaluationContext { - /** Returns the tensor bound to this name, or null if none */ - Tensor getTensor(String name); - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index db8a66a5fa2..cf704c15f4f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } - @Override - public Tensor getTensor(String name) { return bindings.get(name); } + /** Returns the tensor bound to this name, or null if none */ + public Tensor get(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 1f6ad050368..8ade181bdb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -12,18 +12,18 @@ import java.util.List; /** * A tensor variable name which resolves to a tensor in the context at evaluation time - * + * * @author bratseth */ @Beta public class VariableTensor extends PrimitiveTensorFunction { private final String name; - + public VariableTensor(String name) { this.name = name; } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction { @Override public Tensor evaluate(EvaluationContext context) { - return context.getTensor(name); + return ((MapEvaluationContext)context).get(name); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 191c7988443..8f4dbf014a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. - * + * * @author bratseth */ @Beta diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index faa0ca36cb6..1dbb94fdb20 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -15,7 +15,7 @@ import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension - * + * * @author bratseth */ @Beta @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } - + private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); @@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction { Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); return tensor.multiply(unitTensor); } - + } /** Returns the type resulting from concatenating a and b */ @@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Combine two addresses, adding the offset to the concat dimension * - * @return the combined address or null if the addresses are incompatible + * @return the combined address or null if the addresses are incompatible * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, @@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 14ed38718ce..4ac7b21ba90 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -10,18 +10,18 @@ import java.util.List; /** * A function which returns a constant tensor. - * + * * @author bratseth */ @Beta public class ConstantTensor extends PrimitiveTensorFunction { private final Tensor constant; - + public ConstantTensor(String tensorString) { this.constant = Tensor.from(tensorString); } - + public ConstantTensor(Tensor tensor) { this.constant = tensor; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index c75d8ee4753..bbdbd5c3df1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -11,19 +11,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. - * + * * @author bratseth */ public class Diag extends CompositeTensorFunction { private final TensorType type; private final Function<List<Integer>, Double> diagFunction; - + public Diag(TensorType type) { this.type = type; this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction { public String toString(ToStringContext context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e42d25197e2..6ea73b7f310 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -15,7 +15,7 @@ import java.util.function.Function; /** * An indexed tensor whose values are generated by a function - * + * * @author bratseth */ @Beta @@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction { /** * Creates a generated tensor - * + * * @param type the type of the tensor * @param generator the function generating values from a list of ints specifying the indexes of the * tensor cell which will receive the value @@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction { this.type = type; this.generator = generator; } - + private void validateType(TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) if (dimension.type() != TensorType.Dimension.Type.indexedBound) @@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { return this; } - + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); @@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction { } return builder.build(); } - + private DimensionSizes dimensionSizes(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < b.dimensions(); i++) b.set(i, type.dimensions().get(i).size().get()); return b.build(); } - + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ff887e3e9a6..8c4dbfb0acb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator; * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. - * + * * @author bratseth */ @Beta public class Join extends PrimitiveTensorFunction { - + private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; @@ -46,30 +46,6 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } - /** Returns the type resulting from applying Join to the two given types */ - public static TensorType outputType(TensorType a, TensorType b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (int i = 0; i < a.dimensions().size(); ++i) { - TensorType.Dimension aDim = a.dimensions().get(i); - for (int j = 0; j < b.dimensions().size(); ++j) { - TensorType.Dimension bDim = b.dimensions().get(j); - if (aDim.name().equals(bDim.name())) { // include - if (aDim.isIndexed() && bDim.isIndexed()) { - if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), - bDim.size().orElse(Integer.MAX_VALUE))); - else - typeBuilder.indexed(aDim.name()); - } - else { - typeBuilder.mapped(aDim.name()); - } - } - } - } - return typeBuilder.build(); - } - public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @@ -112,11 +88,11 @@ public class Join extends PrimitiveTensorFunction { else return generalJoin(a, b, joinedType); } - + private boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); @@ -138,7 +114,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) @@ -150,7 +126,7 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -158,14 +134,14 @@ public class Join extends PrimitiveTensorFunction { // Find dimensions which are only in the supertype Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); - + for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } - + return builder.build(); } @@ -224,7 +200,7 @@ public class Join extends PrimitiveTensorFunction { subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) @@ -259,7 +235,7 @@ public class Join extends PrimitiveTensorFunction { DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a - for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { @@ -276,7 +252,7 @@ public class Join extends PrimitiveTensorFunction { } } } - + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) @@ -284,7 +260,7 @@ public class Join extends PrimitiveTensorFunction { builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); return builder.build(); } - + /** Returns the sizes from the joined sizes which are present in the type argument */ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); @@ -295,7 +271,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); @@ -364,7 +340,7 @@ public class Join extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -384,7 +360,7 @@ public class Join extends PrimitiveTensorFunction { return TensorAddress.of(joinedLabels); } - /** + /** * Maps the content in the given list to the given array, using the given index map. * * @return true if the mapping was successful, false if one of the destination positions was diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a5e1a016a41..a9872bb42d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -33,8 +32,6 @@ public class Map extends PrimitiveTensorFunction { this.mapper = mapper; } - public static TensorType outputType(TensorType inputType) { return inputType; } - public TensorFunction argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4071917c2b5..bb27e937699 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,7 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.yahoo.tensor.TensorType; import java.util.List; @@ -15,17 +14,13 @@ public class Matmul extends CompositeTensorFunction { private final TensorFunction argument1, argument2; private final String dimension; - + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; } - public static TensorType outputType(TensorType a, TensorType b, String dimension) { - return Join.outputType(a, b); - } - @Override public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } @@ -44,7 +39,7 @@ public class Matmul extends CompositeTensorFunction { Reduce.Aggregator.sum, dimension); } - + @Override public String toString(ToStringContext context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index b7c9a5d2342..efb7b9e500c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor; * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. * Primitive tensor functions are fully inspectable. - * + * * @author bratseth */ @Beta public abstract class PrimitiveTensorFunction extends TensorFunction { - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 958ef85d1dc..457763e97ba 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -22,11 +22,11 @@ import java.util.stream.Stream; public class Random extends CompositeTensorFunction { private final TensorType type; - + public Random(TensorType type) { this.type = type; } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction { public String toString(ToStringContext context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index a56f82b026a..e2b39a2048d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -12,19 +12,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor * indexes of each position. - * + * * @author bratseth */ public class Range extends CompositeTensorFunction { private final TensorType type; private final Function<List<Integer>, Double> rangeFunction; - + public Range(TensorType type) { this.type = type; this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction { public String toString(ToStringContext context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..cfc78be7e0c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions + * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, @@ -61,15 +61,6 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } - public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorType.Dimension dimension : inputType.dimensions()) { - if ( ! reduceDimensions.contains(dimension.name())) - b.dimension(dimension); - } - return b.build(); - } - public TensorFunction argument() { return argument; } @Override @@ -91,7 +82,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -103,7 +94,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -112,14 +103,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { @@ -131,10 +122,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set<Integer> indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -147,7 +138,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) @@ -163,7 +154,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -174,22 +165,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -197,7 +188,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double aggregatedValue() { return valueSum / valueCount; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..6b0daf1b49a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -17,7 +19,7 @@ import java.util.Objects; /** * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -27,10 +29,6 @@ public class Rename extends PrimitiveTensorFunction { private final List<String> fromDimensions; private final List<String> toDimensions; - public Rename(TensorFunction argument, String fromDimension, String toDimension) { - this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); - } - public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); @@ -44,7 +42,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @@ -64,7 +62,7 @@ public class Rename extends PrimitiveTensorFunction { Map<String, String> fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -72,7 +70,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -88,7 +86,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -97,18 +95,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map<String, String> fromToMap() { Map<String, String> map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index fb5029fbfd6..99f79cb735a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -21,87 +21,101 @@ import java.util.stream.Collectors; @Beta public class ScalarFunctions { - public static DoubleBinaryOperator add() { return new Add(); } - public static DoubleBinaryOperator divide() { return new Divide(); } + public static DoubleBinaryOperator add() { return new Addition(); } + public static DoubleBinaryOperator multiply() { return new Multiplication(); } + public static DoubleBinaryOperator divide() { return new Division(); } public static DoubleBinaryOperator equal() { return new Equal(); } - public static DoubleBinaryOperator multiply() { return new Multiply(); } - - public static DoubleUnaryOperator acos() { return new Acos(); } - public static DoubleUnaryOperator exp() { return new Exp(); } - public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } - + public static DoubleUnaryOperator sqrt() { return new Sqrt(); } + public static DoubleUnaryOperator exp() { return new Exponent(); } public static Function<List<Integer>, Double> random() { return new Random(); } public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } - // Binary operators ----------------------------------------------------------------------------- + public static class Addition implements DoubleBinaryOperator { - public static class Add implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left + right; } + @Override public String toString() { return "f(a,b)(a + b)"; } - } - public static class Equal implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } - @Override - public String toString() { return "f(a,b)(a==b)"; } } - public static class Exp implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } - @Override - public String toString() { return "f(a)(exp(a))"; } - } + public static class Multiplication implements DoubleBinaryOperator { - public static class Multiply implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left * right; } + public double applyAsDouble(double left, double right) { return left * right; } + @Override public String toString() { return "f(a,b)(a * b)"; } + } - public static class Divide implements DoubleBinaryOperator { + public static class Division implements DoubleBinaryOperator { + @Override public double applyAsDouble(double left, double right) { return left / right; } + @Override public String toString() { return "f(a,b)(a / b)"; } } - // Unary operators ------------------------------------------------------------------------------ + public static class Equal implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } + + @Override + public String toString() { return "f(a,b)(a==b)"; } + } + + public static class Square implements DoubleUnaryOperator { - public static class Acos implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return Math.acos(operand); } + public double applyAsDouble(double operand) { return operand * operand; } + @Override - public String toString() { return "f(a)(acos(a))"; } + public String toString() { return "f(a)(a * a)"; } + } public static class Sqrt implements DoubleUnaryOperator { + @Override public double applyAsDouble(double operand) { return Math.sqrt(operand); } + @Override public String toString() { return "f(a)(sqrt(a))"; } + } - public static class Square implements DoubleUnaryOperator { + public static class Exponent implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return operand * operand; } + public double applyAsDouble(double operand) { return Math.exp(operand); } @Override - public String toString() { return "f(a)(a * a)"; } + public String toString() { return "f(a)(exp(a))"; } } - // Variable-length operators ----------------------------------------------------------------------------- + public static class Random implements Function<List<Integer>, Double> { + + @Override + public Double apply(List<Integer> values) { + return ThreadLocalRandom.current().nextDouble(); + } + + @Override + public String toString() { return "random"; } - public static class EqualElements implements Function<List<Integer>, Double> { - private final ImmutableList<String> argumentNames; + } + + public static class EqualElements implements Function<List<Integer>, Double> { + + private final ImmutableList<String> argumentNames; + private EqualElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -114,6 +128,7 @@ public class ScalarFunctions { return 0.0; return 1.0; } + @Override public String toString() { if (argumentNames.size() == 0) return "1"; @@ -128,19 +143,13 @@ public class ScalarFunctions { } return b.toString(); } - } - public static class Random implements Function<List<Integer>, Double> { - @Override - public Double apply(List<Integer> values) { - return ThreadLocalRandom.current().nextDouble(); - } - @Override - public String toString() { return "random"; } } public static class SumElements implements Function<List<Integer>, Double> { + private final ImmutableList<String> argumentNames; + private SumElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -152,10 +161,12 @@ public class ScalarFunctions { sum += value; return (double)sum; } + @Override public String toString() { return argumentNames.stream().collect(Collectors.joining("+")); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index c856b548180..bf279eb24d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,8 +2,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -21,10 +19,6 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } - - public static TensorType outputType(TensorType inputType, String dimension) { - return Reduce.outputType(inputType, ImmutableList.of(dimension)); - } @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 533a46f87fe..cabcce198d1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -12,7 +12,7 @@ import java.util.List; * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. * All tensor functions are immutable. - * + * * @author bratseth */ @Beta @@ -48,11 +48,11 @@ public abstract class TensorFunction { /** * Return a string representation of this context. - * + * * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); - + @Override public String toString() { return toString(ToStringContext.empty()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index 416b28afa22..e8c425d49e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -24,7 +24,7 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. - * + * * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index aabb53d1c67..8b7325ec211 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -16,9 +16,9 @@ import java.util.Optional; * * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* * Cell_values = [double, double, double, ...]* - * where values are encoded in order of increasing indexes in each dimension, increasing + * where values are encoded in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. - * + * * @author bratseth */ @Beta @@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat { type = optionalType.get(); TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) - throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 01a1d023f2b..7467554790a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -46,9 +46,9 @@ public class TypedBinaryFormat { return result; } - /** - * Decode some data to a tensor - * + /** + * Decode some data to a tensor + * * @param type the type to decode and validate to, or empty to use the type given in the data * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array * @return the resulting tensor diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index abdb3071bf7..d199dd3a876 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -13,14 +13,14 @@ import java.util.stream.Collectors; /** * Microbenchmark of tensor operations. - * + * * @author bratseth */ public class TensorFunctionBenchmark { private final static Random random = new Random(); - - public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType, + + public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType, boolean extraSpace) { Tensor queryVector = vectors(1, 300, dimensionType).get(0); if (extraSpace) { @@ -34,7 +34,7 @@ public class TensorFunctionBenchmark { long totalTime = System.currentTimeMillis() - startTime; return (double)totalTime / (double)iterations; } - + private Tensor unitVector(String dimension) { return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build()) .cell().label(dimension, 0).value(1).build(); @@ -49,11 +49,11 @@ public class TensorFunctionBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), + TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), + new VariableTensor("argument"), (a, b) -> a * b), Reduce.Aggregator.sum).toPrimitive(); MapEvaluationContext context = new MapEvaluationContext(); - + for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); double dotProduct = dotProductFunction.evaluate(context).asDouble(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 693b0f09351..30078b4a826 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -25,7 +25,7 @@ import static org.junit.Assert.fail; /** * Tests tensor functionality - * + * * @author bratseth */ public class TensorTestCase { @@ -108,7 +108,7 @@ public class TensorTestCase { Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } - + /** Test the same computation made in various ways which are implemented with special-case optimizations */ @Test public void testOptimizedComputation() { @@ -130,7 +130,7 @@ public class TensorTestCase { assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); - + // Test the unoptimized path by joining in another dimension Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build(); Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build(); @@ -138,7 +138,7 @@ public class TensorTestCase { Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK); assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace))); } - + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), @@ -161,7 +161,7 @@ public class TensorTestCase { private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) { return vectors(vectorSize, dimensionType, 1).get(0); } - + /** Create a list of vectors having a single dimension x */ private List<Tensor> vectors(TensorType.Dimension.Type dimensionType, int vectorCount) { return vectors(3, dimensionType, vectorCount); @@ -179,8 +179,8 @@ public class TensorTestCase { } return tensors; } - - /** + + /** * Create a matrix of vectors (in dimension i) where each vector has the dimension x. * This matrix contains the same vectors as returned by createVectors, in a single list element for convenience. */ diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java index f11c068bd74..fab53218b2c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java @@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals; * @author bratseth */ public class JoinTestCase { - + /** Test the indexed subspace join optimization */ @Test public void testJoinIndexedSubspace() { Tensor t1, t2; - + t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}"); t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}"); assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"), @@ -34,10 +34,10 @@ public class JoinTestCase { assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"), t2.divide(t1)); } - + @Test public void testGeneralJoin() { - assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), + assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }") .divide(Tensor.from("tensor(y[]):{{y:0}:2}"))); @@ -45,5 +45,5 @@ public class JoinTestCase { Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }") .divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }"))); } - + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java deleted file mode 100644 index 9643c0a56e7..00000000000 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java +++ /dev/null @@ -1,97 +0,0 @@ -package com.yahoo.tensor.functions; - -import com.google.common.collect.ImmutableList; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * @author bratseth - */ -public class MatmulTestCase { - - @Test - public void testMatmul2d() { - // d0 is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])")); - ab.cell( 1,0, 0); - ab.cell( 2,0, 1); - ab.cell( 3,0, 2); - ab.cell( 4,1, 0); - ab.cell( 5,1, 1); - ab.cell( 6,1, 2); - Tensor a = ab.build(); - - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])")); - bb.cell( 7,0, 0); - bb.cell( 8,0, 1); - bb.cell( 9,1, 0); - bb.cell(10,1, 1); - bb.cell(11,2, 0); - bb.cell(12,2, 1); - Tensor b = bb.build(); - - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])")); - rb.cell( 58,0, 0); - rb.cell( 64,0, 1); - rb.cell(139,1, 0); - rb.cell(154,1, 1); - Tensor r = rb.build(); - - Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1") - .rename("d2","d1"); - assertEquals(r, result); - } - - @Test - public void testMatmul3d() { - // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])")); - ab.cell( 1,0, 0, 0); - ab.cell( 2,0, 0, 1); - ab.cell( 3,0, 0, 2); - ab.cell( 4,0, 1, 0); - ab.cell( 5,0, 1, 1); - ab.cell( 6,0, 1, 2); - ab.cell( 7,1, 0, 0); - ab.cell( 8,1, 0, 1); - ab.cell( 9,1, 0, 2); - ab.cell(10,1, 1, 0); - ab.cell(11,1, 1, 1); - ab.cell(12,1, 1, 2); - Tensor a = ab.build(); - - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])")); - bb.cell(13,0, 0, 0); - bb.cell(14,0, 0, 1); - bb.cell(15,0, 1, 0); - bb.cell(16,0, 1, 1); - bb.cell(17,0, 2, 0); - bb.cell(18,0, 2, 1); - bb.cell(19,1, 0, 0); - bb.cell(20,1, 0, 1); - bb.cell(21,1, 1, 0); - bb.cell(22,1, 1, 1); - bb.cell(23,1, 2, 0); - bb.cell(24,1, 2, 1); - Tensor b = bb.build(); - - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])")); - rb.cell( 94,0, 0, 0); - rb.cell(100,0, 0, 1); - rb.cell(229,0, 1, 0); - rb.cell(244,0, 1, 1); - rb.cell(508,1, 0, 0); - rb.cell(532,1, 0, 1); - rb.cell(697,1, 1, 0); - rb.cell(730,1, 1, 1); - Tensor r = rb.build(); - - Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2") - .rename("d3","d2"); - assertEquals(r, result); - } - -} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index 55069eaced7..8a58cb0bbed 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals; /** * Tests translation of composite to primitive tensor function translation. - * + * * @author bratseth */ public class TensorFunctionTestCase { @@ -16,12 +16,12 @@ public class TensorFunctionTestCase { public void testTranslation() { assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))", new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); - assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", + assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } - + private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { assertEquals(expectedTranslation, inputFunction.toPrimitive().toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 15a872e439f..349309a5052 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase { assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); } - + @Test public void testSerializationToSeparateType() { assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])")); @@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index 33dfca017f4..b1d7d797b3e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java index f002637847b..68bf59e3ed9 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java @@ -50,7 +50,7 @@ public class SerializationTestCase { JsonNode node = mapper.readTree(test); if (node.has("tensor") && node.has("binary")) { System.out.println("Running test: " + test); - + Tensor tensor = buildTensor(node.get("tensor")); String spec = getSpec(node.get("tensor")); byte[] encodedTensor = TypedBinaryFormat.encode(tensor); @@ -123,7 +123,7 @@ public class SerializationTestCase { private byte[] getBytes(String binaryRepresentation) { return parseHexValue(binaryRepresentation.substring(2)); } - + private byte[] parseHexValue(String s) { final int len = s.length(); byte[] bytes = new byte[len/2]; diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index f895b64379b..d17148cf8dc 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase { private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } |