diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-05-25 16:00:24 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-05-25 16:00:24 +0200 |
commit | b299a201c4ffa1c22476f93a08288d1abb97f744 (patch) | |
tree | c564877aa0a8c8056bd3050689ebe1c256a1b4e3 | |
parent | 076e30bf57da0be0f5e6162c43bdf1e2224ba668 (diff) |
More explanation on type mismatch
6 files changed, 22 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 5790a5294eb..41da32f64c3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -127,7 +127,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil addGeneratedMacros(model, profile); reduceBatchDimensions(expression, model, profile, queryProfiles); - model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); store.writeConverted(expression); return expression.getRoot(); @@ -215,8 +215,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); if ( ! macroType.equals(constantValue.type())) throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - "The required type of this is " + constantValue.type() + - ", but the macro returns " + macroType); + typeMismatchExplanation(constantValue.type(), macroType)); constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later } else { @@ -228,7 +227,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void transformGeneratedMacro(ModelStore store, RankProfile profile, + private void transformGeneratedMacro(ModelStore store, Set<String> constantsReplacedByMacros, String macroName, RankingExpression expression) { @@ -267,7 +266,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil RankProfile.Macro macro = profile.getMacros().get(macroName); if (macro == null) - throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + throw new IllegalArgumentException("Model refers placeholder '" + macroName + "' of type " + requiredType + " but this macro is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second @@ -276,18 +275,23 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil // type verification TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) - throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + throw new IllegalArgumentException("Model refers placeholder '" + macroName + "' of type " + requiredType + " which must be produced by a macro in the rank profile, but " + "this macro references a feature which is not declared"); if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers Placeholder '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro produces type " + actualType); + throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " + + typeMismatchExplanation(requiredType, actualType)); } } + private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + return "The required type of this is " + requiredType + ", but this macro returns " + actualType + + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + + "in query profile types - see the documentation." + : ""); + } + /** * Add the generated macros to the rank profile */ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 623f26a6b27..55754605843 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -179,8 +179,8 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + - "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 4ec23f98fc5..e3c72830095 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -255,7 +255,7 @@ public class TensorFlowImporter { } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); + " cannot be parsed as a ranking expression", e); } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java index a394662800e..4f5d61d75f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java @@ -28,12 +28,12 @@ public class ConcatV2 extends TensorFlowOperation { TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input if (!concatDimOp.getConstantValue().isPresent()) { throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + - "concat dimension must be a constant."); + "concat dimension must be a constant."); } Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); if (concatDimTensor.type().rank() != 0) { throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + - "concat dimension must be a scalar."); + "concat dimension must be a scalar."); } OrderedTensorType aType = inputs.get(0).type().get(); @@ -45,7 +45,7 @@ public class ConcatV2 extends TensorFlowOperation { OrderedTensorType bType = inputs.get(i).type().get(); if (bType.rank() != aType.rank()) { throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + - "inputs must have save rank."); + "inputs must have save rank."); } for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); @@ -54,7 +54,7 @@ public class ConcatV2 extends TensorFlowOperation { concatDimSize += dimSizeB; } else if (dimSizeA != dimSizeB) { throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + - "input dimension " + j + " differs in input tensors."); + "input dimension " + j + " differs in input tensors."); } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java index eb4b615b434..1619c11427a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java @@ -53,5 +53,4 @@ public class Placeholder extends TensorFlowOperation { return false; } - } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java index f74d1d6cb75..65ce7f00e34 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; import org.tensorflow.framework.NodeDef; |