diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-28 19:51:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-28 19:51:08 +0200 |
commit | a1c2cbf49a076bd659eb439da0c0298edc8fe224 (patch) | |
tree | 2f4fa936bf95a32f9c183e6ffcb432caa32e7973 /model-integration/src/main | |
parent | 50603040c505026b1431359016704b4f10e302f1 (diff) |
Rename constant dimensions
Diffstat (limited to 'model-integration/src/main')
5 files changed, 71 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 0f563a75b11..e9be35b6f84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -75,7 +75,6 @@ public class DimensionRenamer { if (solution != null) return solution; for (RenameTarget target : prioritizedRenameTargets()) { - System.out.println("Trying rename " + target); target.insertRename(this); solution = solveWithOrWithoutSoftConstraints(maxIterations); if (solution != null) return solution; @@ -90,8 +89,9 @@ public class DimensionRenamer { if ( solution == null) { ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); - if (anyRemoved) + if (anyRemoved) { solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations); + } } return solution; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index d13c1ad5f3c..fc59ad35ef8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -10,6 +10,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -18,6 +20,7 @@ import java.util.Optional; public class Const extends IntermediateOperation { private final AttributeMap attributeMap; + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... public Const(String modelName, String nodeName, @@ -27,6 +30,7 @@ public class Const extends IntermediateOperation { super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.type = type.rename(vespaName() + "_"); + standardNamingType = OrderedTensorType.standardType(type); setConstantValue(value()); } @@ -51,7 +55,13 @@ public class Const extends IntermediateOperation { } else { expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); } - return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + TensorFunction output = new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + if ( ! standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index c3980b8fe93..9c9fed89585 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -102,12 +102,12 @@ public abstract class IntermediateOperation { /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } - /** Conveinence method to adds dimensions and constraints of the given tensor type */ + /** Convenience method to adds dimensions and constraints of the given tensor type */ protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) { for (int i = 0; i < type.dimensions().size(); i++) { renamer.addDimension(type.dimensions().get(i).name()); - // Each dimension is distinct: + // Each dimension is distinct and ordered correctly: for (int j = i + 1; j < type.dimensions().size(); j++) { renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), DimensionRenamer.Constraint.notEqual(false), @@ -216,29 +216,29 @@ public abstract class IntermediateOperation { return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); } - /** - * An interface mapping operation attributes to Vespa Values. - * Adapter for differences in different model types. - */ - public interface AttributeMap { - Optional<Value> get(String key); - Optional<Value> get(String key, OrderedTensorType type); - Optional<List<Value>> getList(String key); - } - public abstract String operationName(); @Override public String toString() { - return operationName() + + return operationName() + "(" + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + ")"; } public String toFullString() { - return "\t" + lazyGetType() + ":\t" + operationName() + + return "\t" + lazyGetType() + ":\t" + operationName() + "(" + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + ")"; } + /** + * An interface mapping operation attributes to Vespa Values. + * Adapter for differences in different model types. + */ + public interface AttributeMap { + Optional<Value> get(String key); + Optional<Value> get(String key, OrderedTensorType type); + Optional<List<Value>> getList(String key); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 357794faee2..0e9c98b2b56 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -34,6 +34,7 @@ import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; +import org.tensorflow.op.core.DecodeRaw; import java.io.IOException; import java.util.List; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 9cba388d00e..6ab7a69e469 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -5,15 +5,16 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import org.tensorflow.DataType; +import org.tensorflow.framework.DataType; import org.tensorflow.framework.TensorProto; +import org.tensorflow.framework.TensorShapeProto; import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; - +import java.util.List; /** * Converts TensorFlow tensors into Vespa tensors. @@ -48,9 +49,11 @@ public class TensorConverter { static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); Values values = readValuesOf(tensorProto); - for (int i = 0; i < values.size(); ++i) { + if (values.size() == 0) // Might be stored as "tensor_content" instead + return toVespaTensor(readTensorContentOf(tensorProto)); + + for (int i = 0; i < values.size(); ++i) builder.cellByDirectIndex(i, values.get(i)); - } return builder.build(); } @@ -74,28 +77,47 @@ public class TensorConverter { case UINT8: return new IntValues(tfTensor); case INT32: return new IntValues(tfTensor); case INT64: return new LongValues(tfTensor); + default: throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + + tfTensor.dataType() + " to a Vespa tensor"); } - throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); } private static Values readValuesOf(TensorProto tensorProto) { switch (tensorProto.getDtype()) { - case DT_BOOL: - return new ProtoBoolValues(tensorProto); - case DT_HALF: - return new ProtoHalfValues(tensorProto); - case DT_INT16: - case DT_INT32: - return new ProtoIntValues(tensorProto); - case DT_INT64: - return new ProtoInt64Values(tensorProto); - case DT_FLOAT: - return new ProtoFloatValues(tensorProto); - case DT_DOUBLE: - return new ProtoDoubleValues(tensorProto); - } - throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + case DT_BOOL: return new ProtoBoolValues(tensorProto); + case DT_HALF: return new ProtoHalfValues(tensorProto); + case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto); + case DT_INT64: return new ProtoInt64Values(tensorProto); + case DT_FLOAT: return new ProtoFloatValues(tensorProto); + case DT_DOUBLE: return new ProtoDoubleValues(tensorProto); + default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + } + + private static Class dataTypeToClass(DataType dataType) { + switch (dataType) { + case DT_BOOL: return Boolean.class; + case DT_INT16: return Short.class; + case DT_INT32: return Integer.class; + case DT_INT64: return Long.class; + case DT_HALF: return Float.class; + case DT_FLOAT: return Float.class; + case DT_DOUBLE: return Double.class; + default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + } + + private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) { + return org.tensorflow.Tensor.create(dataTypeToClass(tensorProto.getDtype()), + asSizeArray(tensorProto.getTensorShape().getDimList()), + tensorProto.getTensorContent().asReadOnlyByteBuffer()); + } + + private static long[] asSizeArray(List<TensorShapeProto.Dim> dimensions) { + long[] sizes = new long[dimensions.size()]; + for (int i = 0; i < dimensions.size(); i++) + sizes[i] = dimensions.get(i).getSize(); + return sizes; } /** Allows reading values from buffers of various numeric types as bytes */ |