diff options
author | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-23 13:40:24 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-23 13:40:24 +0100 |
commit | 64f750ad06fc2723479c50b1a32a2d1a4a0b4297 (patch) | |
tree | a46fa62917f455d5e3501dfe7213058578a8a54b /config-model/src | |
parent | c45fb657f2a10b0f9cb1f57b9a8c7d0c7919ad5e (diff) |
Make catch exception more specific in tensor transformer
Diffstat (limited to 'config-model/src')
3 files changed, 28 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java index b2cd8574076..c75864f81b7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java @@ -21,6 +21,8 @@ import java.util.Map; */ class ConstantTensorTransformer extends ExpressionTransformer { + public static final String CONSTANT = "constant"; + private final Map<String, Value> constants; private final Map<String, String> rankPropertiesOutput; @@ -64,7 +66,7 @@ class ConstantTensorTransformer extends ExpressionTransformer { return node; } TensorValue tensorValue = (TensorValue)value; - String featureName = "constant(" + node.getName() + ")"; + String featureName = CONSTANT + "(" + node.getName() + ")"; String tensorType = tensorValue.asTensor().type().toString(); rankPropertiesOutput.put(featureName + ".value", tensorValue.toString()); rankPropertiesOutput.put(featureName + ".type", tensorType); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java index 4dc4db9645a..69e353ceb35 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java @@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Reduce; import java.util.List; import java.util.Map; +import java.util.Optional; /** * Transforms and simplifies tensor expressions. @@ -81,28 +82,31 @@ public class TensorTransformer extends ExpressionTransformer { return node; } ExpressionNode arg1 = node.children().get(0); - ExpressionNode arg2 = node.children().get(1); - if (!potentialDimensionName(arg2)) { - return node; - } - try { - String dimension = ((ReferenceNode) arg2).getName(); - Context context = buildContext(arg1); - Value value = arg1.evaluate(context); - if (verifyTensorAndDimension(value, dimension)) { - return replaceMaxAndMinFunction(node); + Optional<String> dimension = dimensionName(node.children().get(1)); + if (dimension.isPresent()) { + try { + Context context = buildContext(arg1); + Value value = arg1.evaluate(context); + if (isTensorWithDimension(value, dimension.get())) { + return replaceMaxAndMinFunction(node); + } + } catch (IllegalArgumentException e) { + // Thrown from evaluate if some variables are not bound, for + // instance for a backend rank feature. Means we don't have + // enough information to replace expression. } - } catch (Exception e) { - // Don't replace the expression in case of any errors, e.g. unknown values or rank features } return node; } - private boolean potentialDimensionName(ExpressionNode arg) { - return arg instanceof ReferenceNode && ((ReferenceNode) arg).children().size() == 0; + private Optional<String> dimensionName(ExpressionNode arg) { + if (arg instanceof ReferenceNode && ((ReferenceNode)arg).children().size() == 0) { + return Optional.of(((ReferenceNode) arg).getName()); + } + return Optional.empty(); } - private boolean verifyTensorAndDimension(Value value, String dimension) { + private boolean isTensorWithDimension(Value value, String dimension) { if (value instanceof TensorValue) { Tensor tensor = ((TensorValue) value).asTensor(); TensorType type = tensor.type(); @@ -175,6 +179,9 @@ public class TensorTransformer extends ExpressionTransformer { } String attribute = node.children().get(0).toString(); Attribute a = search.getAttribute(attribute); + if (a == null) { + return; + } Value v; if (a.getType() == Attribute.Type.STRING) { v = emptyStringValue(); @@ -187,7 +194,7 @@ public class TensorTransformer extends ExpressionTransformer { } private void addIfConstant(ReferenceNode node, Context context) { - if (!node.getName().equals("constant")) { + if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { return; } if (node.children().size() != 1) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index aa3fd4e9aae..12bdd8d2b5c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -63,7 +63,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); - assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); + assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test |