summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-28 19:51:08 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-28 19:51:08 +0200
commita1c2cbf49a076bd659eb439da0c0298edc8fe224 (patch)
tree2f4fa936bf95a32f9c183e6ffcb432caa32e7973 /model-integration/src/main
parent50603040c505026b1431359016704b4f10e302f1 (diff)
Rename constant dimensions
Diffstat (limited to 'model-integration/src/main')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java28
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java64
5 files changed, 71 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 0f563a75b11..e9be35b6f84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -75,7 +75,6 @@ public class DimensionRenamer {
if (solution != null) return solution;
for (RenameTarget target : prioritizedRenameTargets()) {
- System.out.println("Trying rename " + target);
target.insertRename(this);
solution = solveWithOrWithoutSoftConstraints(maxIterations);
if (solution != null) return solution;
@@ -90,8 +89,9 @@ public class DimensionRenamer {
if ( solution == null) {
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
- if (anyRemoved)
+ if (anyRemoved) {
solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations);
+ }
}
return solution;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index d13c1ad5f3c..fc59ad35ef8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -10,6 +10,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -18,6 +20,7 @@ import java.util.Optional;
public class Const extends IntermediateOperation {
private final AttributeMap attributeMap;
+ private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
public Const(String modelName,
String nodeName,
@@ -27,6 +30,7 @@ public class Const extends IntermediateOperation {
super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
this.type = type.rename(vespaName() + "_");
+ standardNamingType = OrderedTensorType.standardType(type);
setConstantValue(value());
}
@@ -51,7 +55,13 @@ public class Const extends IntermediateOperation {
} else {
expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
}
- return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ TensorFunction output = new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ if ( ! standardNamingType.equals(type)) {
+ List<String> renameFrom = standardNamingType.dimensionNames();
+ List<String> renameTo = type.dimensionNames();
+ output = new Rename(output, renameFrom, renameTo);
+ }
+ return output;
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index c3980b8fe93..9c9fed89585 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -102,12 +102,12 @@ public abstract class IntermediateOperation {
/** Add dimension name constraints for this operation */
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
- /** Conveinence method to adds dimensions and constraints of the given tensor type */
+ /** Convenience method to adds dimensions and constraints of the given tensor type */
protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) {
for (int i = 0; i < type.dimensions().size(); i++) {
renamer.addDimension(type.dimensions().get(i).name());
- // Each dimension is distinct:
+ // Each dimension is distinct and ordered correctly:
for (int j = i + 1; j < type.dimensions().size(); j++) {
renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
DimensionRenamer.Constraint.notEqual(false),
@@ -216,29 +216,29 @@ public abstract class IntermediateOperation {
return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
}
- /**
- * An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in different model types.
- */
- public interface AttributeMap {
- Optional<Value> get(String key);
- Optional<Value> get(String key, OrderedTensorType type);
- Optional<List<Value>> getList(String key);
- }
-
public abstract String operationName();
@Override
public String toString() {
- return operationName() +
+ return operationName() + "(" +
inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) +
")";
}
public String toFullString() {
- return "\t" + lazyGetType() + ":\t" + operationName() +
+ return "\t" + lazyGetType() + ":\t" + operationName() + "(" +
inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) +
")";
}
+ /**
+ * An interface mapping operation attributes to Vespa Values.
+ * Adapter for differences in different model types.
+ */
+ public interface AttributeMap {
+ Optional<Value> get(String key);
+ Optional<Value> get(String key, OrderedTensorType type);
+ Optional<List<Value>> getList(String key);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 357794faee2..0e9c98b2b56 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -34,6 +34,7 @@ import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
+import org.tensorflow.op.core.DecodeRaw;
import java.io.IOException;
import java.util.List;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 9cba388d00e..6ab7a69e469 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -5,15 +5,16 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import org.tensorflow.DataType;
+import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
+import org.tensorflow.framework.TensorShapeProto;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
-
+import java.util.List;
/**
* Converts TensorFlow tensors into Vespa tensors.
@@ -48,9 +49,11 @@ public class TensorConverter {
static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
Values values = readValuesOf(tensorProto);
- for (int i = 0; i < values.size(); ++i) {
+ if (values.size() == 0) // Might be stored as "tensor_content" instead
+ return toVespaTensor(readTensorContentOf(tensorProto));
+
+ for (int i = 0; i < values.size(); ++i)
builder.cellByDirectIndex(i, values.get(i));
- }
return builder.build();
}
@@ -74,28 +77,47 @@ public class TensorConverter {
case UINT8: return new IntValues(tfTensor);
case INT32: return new IntValues(tfTensor);
case INT64: return new LongValues(tfTensor);
+ default: throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
}
- throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
}
private static Values readValuesOf(TensorProto tensorProto) {
switch (tensorProto.getDtype()) {
- case DT_BOOL:
- return new ProtoBoolValues(tensorProto);
- case DT_HALF:
- return new ProtoHalfValues(tensorProto);
- case DT_INT16:
- case DT_INT32:
- return new ProtoIntValues(tensorProto);
- case DT_INT64:
- return new ProtoInt64Values(tensorProto);
- case DT_FLOAT:
- return new ProtoFloatValues(tensorProto);
- case DT_DOUBLE:
- return new ProtoDoubleValues(tensorProto);
- }
- throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ case DT_BOOL: return new ProtoBoolValues(tensorProto);
+ case DT_HALF: return new ProtoHalfValues(tensorProto);
+ case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto);
+ case DT_INT64: return new ProtoInt64Values(tensorProto);
+ case DT_FLOAT: return new ProtoFloatValues(tensorProto);
+ case DT_DOUBLE: return new ProtoDoubleValues(tensorProto);
+ default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ }
+ }
+
+ private static Class dataTypeToClass(DataType dataType) {
+ switch (dataType) {
+ case DT_BOOL: return Boolean.class;
+ case DT_INT16: return Short.class;
+ case DT_INT32: return Integer.class;
+ case DT_INT64: return Long.class;
+ case DT_HALF: return Float.class;
+ case DT_FLOAT: return Float.class;
+ case DT_DOUBLE: return Double.class;
+ default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ }
+ }
+
+ private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) {
+ return org.tensorflow.Tensor.create(dataTypeToClass(tensorProto.getDtype()),
+ asSizeArray(tensorProto.getTensorShape().getDimList()),
+ tensorProto.getTensorContent().asReadOnlyByteBuffer());
+ }
+
+ private static long[] asSizeArray(List<TensorShapeProto.Dim> dimensions) {
+ long[] sizes = new long[dimensions.size()];
+ for (int i = 0; i < dimensions.size(); i++)
+ sizes[i] = dimensions.get(i).getSize();
+ return sizes;
}
/** Allows reading values from buffers of various numeric types as bytes */