summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-31 11:13:51 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-31 11:13:51 +0100
commita44edeba9f38c38c431d7b9b6e1ac454e2a0e610 (patch)
tree21600936cfe396492965764911652b49b4c22731 /searchlib
parent9c4ba9bf5b96b8c62a9b8c5a6c20a9175c698b70 (diff)
Verify macros
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java35
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java9
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java7
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java7
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java10
32 files changed, 162 insertions, 135 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index a4d3c111356..5f8daa69ecf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -81,7 +82,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
}
@Override
- public ValueType getType(String name) {
+ public TensorType getType(String name) {
Integer index = nameToIndex().get(name);
if (index == null) return null;
return values[index].type();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index a1e79df95e3..861f9565d66 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -25,21 +24,10 @@ public abstract class Context implements EvaluationContext {
*/
public abstract Value get(String name);
- /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */
- @Override
- public TensorType getTensorType(String name) {
- ValueType type = getType(name);
- if (type == null) return null;
- return type.tensorType();
- }
-
/** Returns a variable as a tensor */
@Override
public Tensor getTensor(String name) { return get(name).asTensor(); }
- /** Returns the type of the value of the given variable, or null if there is no such variable */
- public abstract ValueType getType(String name);
-
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index c85a8f1c7e1..3ac11cff0cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
@@ -13,7 +14,7 @@ import com.yahoo.tensor.Tensor;
public abstract class DoubleCompatibleValue extends Value {
@Override
- public ValueType type() { return ValueType.doubleType(); }
+ public TensorType type() { return TensorType.empty; }
@Override
public boolean hasDouble() { return true; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index 34cd75df9cb..0625e8506cc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
/**
* A variant of an array context variant which supports faster binding of variables but slower lookup
@@ -67,7 +68,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
}
@Override
- public ValueType getType(String name) { return ValueType.doubleType(); }
+ public TensorType getType(String name) { return TensorType.empty; }
/** Perform a slow lookup by name */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 2672fe6cd8e..39efe641f26 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.tensor.TensorType;
+
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -13,7 +15,7 @@ import java.util.Set;
*/
public class MapContext extends Context {
- private Map<String,Value> bindings=new HashMap<>();
+ private Map<String, Value> bindings = new HashMap<>();
private boolean frozen = false;
@@ -21,16 +23,6 @@ public class MapContext extends Context {
}
/**
- * Freezes this.
- * Returns this for convenience.
- */
- public MapContext freeze() {
- if ( ! frozen)
- bindings = Collections.unmodifiableMap(bindings);
- return this;
- }
-
- /**
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
@@ -41,27 +33,32 @@ public class MapContext extends Context {
boundValue.freeze();
}
+ /**
+ * Freezes this.
+ * Returns this for convenience.
+ */
+ public MapContext freeze() {
+ if ( ! frozen)
+ bindings = Collections.unmodifiableMap(bindings);
+ return this;
+ }
+
/** Returns the type of the given value key, or null if it is not bound. */
@Override
- public ValueType getType(String key) {
+ public TensorType getType(String key) {
Value value = bindings.get(key);
if (value == null) return null;
return value.type();
}
- /**
- * Returns the value of a key. 0 is returned if the given key is not bound in this.
- */
+ /** Returns the value of a key. 0 is returned if the given key is not bound in this. */
@Override
public Value get(String key) {
return bindings.getOrDefault(key, DoubleValue.zero);
}
/**
- * Sets the value of a key.
- * The value is frozen by this.
- *
- * @since 5.1.5
+ * Sets the value of a key.The value is frozen by this.
*/
@Override
public void put(String key,Value value) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 874b41ec3e1..c60507310f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -5,6 +5,7 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* A string value.
@@ -29,7 +30,7 @@ public class StringValue extends Value {
}
@Override
- public ValueType type() { return ValueType.doubleType(); }
+ public TensorType type() { return TensorType.empty; }
/** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index f1c65dc79d3..c6e456f285d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A Value containing a tensor.
@@ -25,7 +26,7 @@ public class TensorValue extends Value {
}
@Override
- public ValueType type() { return ValueType.doubleType(); }
+ public TensorType type() { return TensorType.empty; }
@Override
public double asDouble() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
new file mode 100644
index 00000000000..f2c4ca58f6d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
@@ -0,0 +1,27 @@
+package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A context which only contains type information.
+ *
+ * @author bratseth
+ */
+public class TypeMapContext implements TypeContext {
+
+ private final Map<String, TensorType> featureTypes = new HashMap<>();
+
+ public void setType(String name, TensorType type) {
+ featureTypes.put(name, type);
+ }
+
+ @Override
+ public TensorType getType(String name) {
+ return featureTypes.get(name);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 856bfb3638d..59d2d95b879 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -13,14 +13,14 @@ import com.yahoo.tensor.TensorType;
* Concrete subclasses of this provides implementations of these methods or throws
* UnsupportedOperationException if the operation is not supported.
*
- * @author bratseth
+ * @author bratseth
*/
public abstract class Value {
private boolean frozen=false;
/** Returns the type of this value */
- public abstract ValueType type();
+ public abstract TensorType type();
/** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */
public abstract double asDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
deleted file mode 100644
index 046ad7861ef..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
+++ /dev/null
@@ -1,32 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import com.yahoo.tensor.TensorType;
-
-/**
- * The type of a ranking expression value - either a double or a tensor.
- *
- * @author bratseth
- */
-public class ValueType {
-
- private static final ValueType doubleValueType = new ValueType(TensorType.empty);
-
- private final TensorType tensorType;
-
- private ValueType(TensorType tensorType) {
- this.tensorType = tensorType;
- }
-
- /** Returns true if this is the double type */
- public boolean isDouble() { return tensorType.rank() == 0; }
-
- /** The type of this as a tensor type. The double type is the empty tensor type (rank 0) */
- public TensorType tensorType() { return tensorType; }
-
- /** Returns the type representing a double */
- public static ValueType doubleType() { return doubleValueType; }
-
- /** Returns a type representing the given tensor type */
- public static ValueType of(TensorType type) { return new ValueType(type); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index b4e126f69e0..8ee4cdbf297 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -25,7 +26,7 @@ public class GBDTForestNode extends ExpressionNode {
}
@Override
- public final ValueType type(Context context) { return ValueType.doubleType(); }
+ public final TensorType type(TypeContext context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index f085194a7df..aac635b2545 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -50,7 +51,7 @@ public final class GBDTNode extends ExpressionNode {
public final double[] values() { return values; }
@Override
- public final ValueType type(Context context) { return ValueType.doubleType(); }
+ public final TensorType type(TypeContext context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 816ef38e128..cdcb4df0360 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -147,14 +147,12 @@ class OperationMapper {
return operation.get().map(params);
}
params.signature().importWarning("TensorFlow operation '" + params.node().getOp() +
- "' in node '" + params.node().getName() + "' is not supported.");
+ "' in node '" + params.node().getName() + "' is not supported.");
return Optional.empty();
}
- /*
- * Operations
- */
+ // Operations ---------------------------------
private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
@@ -209,10 +207,11 @@ class OperationMapper {
TensorType type = params.result().arguments().get(name);
if (type == null) {
throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
- "', but there is no such placeholder");
+ "', but there is no such placeholder");
}
+ params.result().requiredMacro(name, type);
// Included literally in the expression and so must be produced by a separate macro in the rank profile
- TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name));
+ TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type));
return Optional.of(output);
}
@@ -227,7 +226,7 @@ class OperationMapper {
}
private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
+ if ( ! checkInputs(params, 2)) {
return Optional.empty();
}
List<Optional<TypedTensorFunction>> inputs = params.inputs();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 3a6b3f23a1d..6d78b501fdc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -14,7 +14,6 @@ import org.tensorflow.framework.TensorInfo;
import java.io.File;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -109,10 +108,10 @@ public class TensorFlowImporter {
}
Optional<TypedTensorFunction> function = OperationMapper.map(params);
- if (!function.isPresent()) {
+ if ( ! function.isPresent()) {
return Optional.empty();
}
- if (!controlDependenciesArePresent(params)) {
+ if ( ! controlDependenciesArePresent(params)) {
return Optional.empty();
}
params.imported().put(nodeName, function.get());
@@ -185,6 +184,7 @@ public class TensorFlowImporter {
/** Parameter object to hold important data while importing */
static final class Parameters {
+
private final TensorFlowImporter owner;
private final GraphDef graph;
private final SavedModelBundle model;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 60aaf8ddce1..fe725e50a3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -27,11 +27,13 @@ public class TensorFlowModel {
private final Map<String, Tensor> constants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
@@ -51,11 +53,12 @@ public class TensorFlowModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /**
- * Returns an immutable map of expressions that can be overridden - such as PlaceholderWithDefault/
- */
+ /** Returns an immutable map of macros that are part of this model */
public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+ /** Returns an immutable map of the macros that must be provided by the environment running this model */
+ public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
+
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index d45037b6044..fc6428a4c33 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Join;
import java.util.ArrayDeque;
@@ -80,14 +80,14 @@ public final class ArithmeticNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
// Compute type using tensor types as arithmetic operators are supported on tensors
// and is correct also in the special case of doubles.
// As all our functions are type-commutative, we don't need to take operator precedence into account
- TensorType type = children.get(0).type(context).tensorType();
+ TensorType type = children.get(0).type(context);
for (int i = 1; i < children.size(); i++)
- type = Join.outputType(type, children.get(i).type(context).tensorType());
- return ValueType.of(type);
+ type = Join.outputType(type, children.get(i).type(context));
+ return type;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index fdbb22093ea..7601c0e6180 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
@@ -48,8 +49,8 @@ public class ComparisonNode extends BooleanNode {
}
@Override
- public ValueType type(Context context) {
- return ValueType.doubleType(); // by definition
+ public TensorType type(TypeContext context) {
+ return TensorType.empty; // by definition
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index e6074a5f745..1ea8d03f0eb 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -48,7 +49,7 @@ public final class ConstantNode extends ExpressionNode {
}
@Override
- public ValueType type(Context context) { return value.type(); }
+ public TensorType type(TypeContext context) { return value.type(); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index 8404226c33b..fd9fab99db8 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -49,7 +50,7 @@ public final class EmbracedNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index 5d06a562b5d..477f4db4981 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.io.Serializable;
import java.util.Deque;
@@ -47,7 +48,7 @@ public abstract class ExpressionNode implements Serializable {
* @param context the variable type bindings to use for this evaluation
* @throws IllegalArgumentException if there are variables which are not bound in the given map
*/
- public abstract ValueType type(Context context);
+ public abstract TensorType type(TypeContext context);
/**
* Returns the value of evaluating this expression over the given context.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index b187b8f029c..79515229019 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -4,7 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Join;
import java.util.ArrayList;
@@ -66,16 +67,16 @@ public final class FunctionNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
if (arguments.expressions().size() == 0)
- return ValueType.doubleType();
+ return TensorType.empty;
- ValueType argument1Type = arguments.expressions().get(0).type(context);
+ TensorType argument1Type = arguments.expressions().get(0).type(context);
if (arguments.expressions().size() == 1)
return argument1Type;
- ValueType argument2Type = arguments.expressions().get(1).type(context);
- return ValueType.of(Join.outputType(argument1Type.tensorType(), argument2Type.tensorType()));
+ TensorType argument2Type = arguments.expressions().get(1).type(context);
+ return Join.outputType(argument1Type, argument2Type);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index fcd40bed4d0..e42884ecc05 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -48,7 +48,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) { return ValueType.of(type); }
+ public TensorType type(TypeContext context) { return type; }
/** Evaluate this in a context which must have the arguments bound */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index b9866bec027..076df327044 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -3,9 +3,13 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
/**
* A conditional branch of a ranking expression.
@@ -71,9 +75,9 @@ public final class IfNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
- ValueType trueType = trueExpression.type(context);
- ValueType falseType = falseExpression.type(context);
+ public TensorType type(TypeContext context) {
+ TensorType trueType = trueExpression.type(context);
+ TensorType falseType = falseExpression.type(context);
if ( ! trueType.equals(falseType))
throw new IllegalArgumentException("An if expression must produce a value of the same type in both " +
"alternatives, but the 'true' type is " + trueType + " while the " +
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index b898529c4b9..da946228291 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -5,7 +5,8 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -56,8 +57,8 @@ public class LambdaFunctionNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
- return ValueType.doubleType(); // by definition - no nested lambdas
+ public TensorType type(TypeContext context) {
+ return TensorType.empty; // by definition - no nested lambdas
}
/** Evaluate this in a context which must have the arguments bound */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index cf6475238c4..f55ed59b65c 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -31,7 +32,7 @@ public final class NameNode extends ExpressionNode {
}
@Override
- public ValueType type(Context context) { throw new RuntimeException("Named nodes can not have a type"); }
+ public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 2e685a6c8ab..9cbe5f98c72 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -37,7 +38,7 @@ public class NegativeNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index c4b940f1bd6..e7041600635 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -37,7 +38,7 @@ public class NotNode extends BooleanNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index e5176f9966d..f79297f7773 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -5,7 +5,8 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayDeque;
import java.util.Deque;
@@ -45,9 +46,8 @@ public final class ReferenceNode extends CompositeNode {
return new ReferenceNode(name, arguments, output);
}
- public String getOutput() {
- return output;
- }
+ /** Returns the specific output this references, or null if none specified */
+ public String getOutput() { return output; }
/** Returns a copy of this node with a modified output */
public ReferenceNode setOutput(String output) {
@@ -106,7 +106,7 @@ public final class ReferenceNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) {
+ public TensorType type(TypeContext context) {
// Don't support outputs of different type, for simplicity
return context.getType(name);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index a8b82c560f7..a7b82f4753f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -6,8 +6,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
@@ -59,8 +60,8 @@ public class SetMembershipNode extends BooleanNode {
}
@Override
- public ValueType type(Context context) {
- return ValueType.doubleType();
+ public TensorType type(TypeContext context) {
+ return TensorType.empty;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 97cfa2a5350..e4c381972e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -5,10 +5,10 @@ import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -64,7 +64,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public ValueType type(Context context) { return ValueType.of(function.type(context)); }
+ public TensorType type(TypeContext context) { return function.type(context); }
@Override
public Value evaluate(Context context) {
@@ -111,8 +111,8 @@ public class TensorFunctionNode extends CompositeNode {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(EvaluationContext context) {
- return expression.type((Context)context).tensorType();
+ public TensorType type(TypeContext context) {
+ return expression.type(context);
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index b59b4750911..445ccf231a7 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -2,10 +2,12 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* @author lesters
@@ -15,6 +17,18 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
+
+ // Check (provided) macros
+ assertEquals(1, model.get().macros().size());
+ assertTrue(model.get().macros().containsKey("training/input"));
+ assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString());
+
+ // Check required macros
+ assertEquals(1, model.get().requiredMacros().size());
+ assertTrue(model.get().requiredMacros().containsKey("X"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().requiredMacros().get("X"));
+
TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index f12b9a2c628..01dd15d5fa0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -8,6 +8,7 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* @author bratseth
@@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase {
constant1.type());
assertEquals(10, constant1.size());
+ // Check (provided) macros
+ assertEquals(0, model.get().macros().size());
+
+ // Check required macros
+ assertEquals(1, model.get().requiredMacros().size());
+ assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().requiredMacros().get("Placeholder"));
+
// Check signatures
assertEquals(1, model.get().signatures().size());
TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");