summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@yahoo-inc.com>2017-11-23 13:40:24 +0100
committerLester Solbakken <lesters@yahoo-inc.com>2017-11-23 13:40:24 +0100
commit64f750ad06fc2723479c50b1a32a2d1a4a0b4297 (patch)
treea46fa62917f455d5e3501dfe7213058578a8a54b /config-model
parentc45fb657f2a10b0f9cb1f57b9a8c7d0c7919ad5e (diff)
Make catch exception more specific in tensor transformer
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java39
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java3
3 files changed, 28 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java
index b2cd8574076..c75864f81b7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java
@@ -21,6 +21,8 @@ import java.util.Map;
*/
class ConstantTensorTransformer extends ExpressionTransformer {
+ public static final String CONSTANT = "constant";
+
private final Map<String, Value> constants;
private final Map<String, String> rankPropertiesOutput;
@@ -64,7 +66,7 @@ class ConstantTensorTransformer extends ExpressionTransformer {
return node;
}
TensorValue tensorValue = (TensorValue)value;
- String featureName = "constant(" + node.getName() + ")";
+ String featureName = CONSTANT + "(" + node.getName() + ")";
String tensorType = tensorValue.asTensor().type().toString();
rankPropertiesOutput.put(featureName + ".value", tensorValue.toString());
rankPropertiesOutput.put(featureName + ".type", tensorType);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
index 4dc4db9645a..69e353ceb35 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
@@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Reduce;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
/**
* Transforms and simplifies tensor expressions.
@@ -81,28 +82,31 @@ public class TensorTransformer extends ExpressionTransformer {
return node;
}
ExpressionNode arg1 = node.children().get(0);
- ExpressionNode arg2 = node.children().get(1);
- if (!potentialDimensionName(arg2)) {
- return node;
- }
- try {
- String dimension = ((ReferenceNode) arg2).getName();
- Context context = buildContext(arg1);
- Value value = arg1.evaluate(context);
- if (verifyTensorAndDimension(value, dimension)) {
- return replaceMaxAndMinFunction(node);
+ Optional<String> dimension = dimensionName(node.children().get(1));
+ if (dimension.isPresent()) {
+ try {
+ Context context = buildContext(arg1);
+ Value value = arg1.evaluate(context);
+ if (isTensorWithDimension(value, dimension.get())) {
+ return replaceMaxAndMinFunction(node);
+ }
+ } catch (IllegalArgumentException e) {
+ // Thrown from evaluate if some variables are not bound, for
+ // instance for a backend rank feature. Means we don't have
+ // enough information to replace expression.
}
- } catch (Exception e) {
- // Don't replace the expression in case of any errors, e.g. unknown values or rank features
}
return node;
}
- private boolean potentialDimensionName(ExpressionNode arg) {
- return arg instanceof ReferenceNode && ((ReferenceNode) arg).children().size() == 0;
+ private Optional<String> dimensionName(ExpressionNode arg) {
+ if (arg instanceof ReferenceNode && ((ReferenceNode)arg).children().size() == 0) {
+ return Optional.of(((ReferenceNode) arg).getName());
+ }
+ return Optional.empty();
}
- private boolean verifyTensorAndDimension(Value value, String dimension) {
+ private boolean isTensorWithDimension(Value value, String dimension) {
if (value instanceof TensorValue) {
Tensor tensor = ((TensorValue) value).asTensor();
TensorType type = tensor.type();
@@ -175,6 +179,9 @@ public class TensorTransformer extends ExpressionTransformer {
}
String attribute = node.children().get(0).toString();
Attribute a = search.getAttribute(attribute);
+ if (a == null) {
+ return;
+ }
Value v;
if (a.getType() == Attribute.Type.STRING) {
v = emptyStringValue();
@@ -187,7 +194,7 @@ public class TensorTransformer extends ExpressionTransformer {
}
private void addIfConstant(ReferenceNode node, Context context) {
- if (!node.getName().equals("constant")) {
+ if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) {
return;
}
if (node.children().size() != 1) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index aa3fd4e9aae..12bdd8d2b5c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -63,7 +63,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)");
assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
- assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)");
+ assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
+ assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
}
@Test