aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-07 15:49:00 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-07 15:49:00 +0100
commite3f5b2812b51896d4a4d5304e6e8c7060e60f68a (patch)
tree50f86020fe0b4c295bffd14bb877251293593f28 /searchlib/src/main
parent584dcf0b4a54b5e5a70696c15ee0c2bfe63ab656 (diff)
Allow macros to replace TenorFlow variables
Also, remove quoting of constant arguments generated in TensorFlow as that is unnecessary now and is interpreted as a string constant argument to a macro.
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java4
4 files changed, 14 insertions, 26 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index 0fe73fad8ce..ee358f45b22 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -34,9 +34,7 @@ public class OperationMapper {
public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
switch (node.getOp().toLowerCase()) {
- /*
- * array ops
- */
+ // array ops
case "const": return new Const(node, inputs, port);
case "expanddims": return new ExpandDims(node, inputs, port);
case "identity": return new Identity(node, inputs, port);
@@ -46,15 +44,11 @@ public class OperationMapper {
case "shape": return new Shape(node, inputs, port);
case "squeeze": return new Squeeze(node, inputs, port);
- /*
- * control flow
- */
+ // control flow
case "merge": return new Merge(node, inputs, port);
case "switch": return new Switch(node, inputs, port);
- /*
- * math ops
- */
+ // math ops
case "add": return new Join(node, inputs, port, ScalarFunctions.add());
case "add_n": return new Join(node, inputs, port, ScalarFunctions.add());
case "acos": return new Map(node, inputs, port, ScalarFunctions.acos());
@@ -75,27 +69,17 @@ public class OperationMapper {
case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract());
case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract());
- /*
- * nn ops
- */
+ // nn ops
case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add());
case "elu": return new Map(node, inputs, port, ScalarFunctions.elu());
case "relu": return new Map(node, inputs, port, ScalarFunctions.relu());
case "selu": return new Map(node, inputs, port, ScalarFunctions.selu());
- /*
- * random ops
- */
-
- /*
- * state ops
- */
+ // state ops
case "variable": return new Variable(node, inputs, port);
case "variablev2": return new Variable(node, inputs, port);
- /*
- * evaluation no-ops
- */
+ // evaluation no-ops
case "stopgradient":return new Identity(node, inputs, port);
case "noop": return new NoOp(node, inputs, port);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
index 7decef51ab7..d06d7b48def 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -46,7 +47,7 @@ public class Const extends TensorFlowOperation {
if (type.type().rank() == 0 && getConstantValue().isPresent()) {
expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
} else {
- expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
}
return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
}
@@ -72,7 +73,7 @@ public class Const extends TensorFlowOperation {
private Value value() {
if (!node.getAttrMap().containsKey("value")) {
throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
- "const has missing 'value' attribute");
+ "const has missing 'value' attribute");
}
AttrValue attrValue = node.getAttrMap().get("value");
if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index 9e8f6df3e2c..5d711aac100 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
@@ -67,7 +68,7 @@ public abstract class TensorFlowOperation {
public Optional<TensorFunction> function() {
if (function == null) {
if (isConstant()) {
- ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
} else if (outputs.size() > 1) {
macro = lazyGetFunction();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 822d6055815..639c5d22d9e 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -105,7 +106,8 @@ public final class ReferenceNode extends CompositeNode {
// TODO: Context should accept a Reference instead.
if (reference.isIdentifier())
return context.get(reference.name());
- return context.get(getName(), getArguments(), getOutput());
+ else
+ return context.get(getName(), getArguments(), getOutput());
}
@Override