diff options
Diffstat (limited to 'model-integration')
4 files changed, 8 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index b092b292627..fc895b07d53 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -20,7 +20,6 @@ import java.util.Optional; public class Const extends IntermediateOperation { private final AttributeMap attributeMap; - private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... public Const(String modelName, String nodeName, @@ -30,7 +29,6 @@ public class Const extends IntermediateOperation { super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.type = type.rename(vespaName() + "_"); - standardNamingType = OrderedTensorType.standardType(type); setConstantValue(value()); } @@ -55,13 +53,7 @@ public class Const extends IntermediateOperation { } else { expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); } - TensorFunction output = new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); - if ( ! standardNamingType.equals(type)) { - List<String> renameFrom = standardNamingType.dimensionNames(); - List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); - } - return output; + return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java index ecb67f93d69..f2c6dfd9069 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java @@ -57,7 +57,7 @@ class AttributeConverter implements IntermediateOperation.AttributeMap { if (attributeMap.containsKey(key)) { AttrValue attrValue = attributeMap.get(key); if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type()))); + return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type))); } } return get(key); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 6ab7a69e469..95727acb5b4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -46,12 +46,12 @@ public class TensorConverter { return builder.build(); } - static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) { Values values = readValuesOf(tensorProto); - if (values.size() == 0) // Might be stored as "tensor_content" instead - return toVespaTensor(readTensorContentOf(tensorProto)); - + if (values.size() == 0) { // Might be stored as "tensor_content" instead + return toVespaTensor(readTensorContentOf(tensorProto), type); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); ++i) builder.cellByDirectIndex(i, values.get(i)); return builder.build(); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index f38403bfbd4..b9d767774be 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -34,7 +34,7 @@ public class DropoutImportTestCase { ImportedMlFunction function = signature.outputFunction("y", "y"); assertNotNull(function); - assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(rename(constant(test_outputs_Const), d0, d1), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", function.expression()); model.assertEqualResult("X", "outputs/Maximum"); assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); |