summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-09-19 15:03:41 +0200
committerLester Solbakken <lesters@oath.com>2019-09-19 15:03:41 +0200
commitb6e4d92c3996270ab9397a8c7a06eb909b88b6a7 (patch)
treeb3a6e15910d455a81149e481bacc63e445a8d690 /model-integration
parentd082531b8c6244de5bc99ed887f706be3a1084df (diff)
Avoid rename operations for tensorflow protobuf constants
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java2
4 files changed, 8 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index fc59ad35ef8..3ad5cb1d19f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -20,7 +20,6 @@ import java.util.Optional;
public class Const extends IntermediateOperation {
private final AttributeMap attributeMap;
- private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
public Const(String modelName,
String nodeName,
@@ -30,7 +29,6 @@ public class Const extends IntermediateOperation {
super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
this.type = type.rename(vespaName() + "_");
- standardNamingType = OrderedTensorType.standardType(type);
setConstantValue(value());
}
@@ -55,13 +53,7 @@ public class Const extends IntermediateOperation {
} else {
expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
}
- TensorFunction output = new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
- if ( ! standardNamingType.equals(type)) {
- List<String> renameFrom = standardNamingType.dimensionNames();
- List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
- }
- return output;
+ return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
index ecb67f93d69..f2c6dfd9069 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
@@ -57,7 +57,7 @@ class AttributeConverter implements IntermediateOperation.AttributeMap {
if (attributeMap.containsKey(key)) {
AttrValue attrValue = attributeMap.get(key);
if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type)));
}
}
return get(key);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 6ab7a69e469..95727acb5b4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -46,12 +46,12 @@ public class TensorConverter {
return builder.build();
}
- static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) {
Values values = readValuesOf(tensorProto);
- if (values.size() == 0) // Might be stored as "tensor_content" instead
- return toVespaTensor(readTensorContentOf(tensorProto));
-
+ if (values.size() == 0) { // Might be stored as "tensor_content" instead
+ return toVespaTensor(readTensorContentOf(tensorProto), type);
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type());
for (int i = 0; i < values.size(); ++i)
builder.cellByDirectIndex(i, values.get(i));
return builder.build();
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index f38403bfbd4..b9d767774be 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -34,7 +34,7 @@ public class DropoutImportTestCase {
ImportedMlFunction function = signature.outputFunction("y", "y");
assertNotNull(function);
- assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(rename(constant(test_outputs_Const), d0, d1), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
function.expression());
model.assertEqualResult("X", "outputs/Maximum");
assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());