diff options
Diffstat (limited to 'searchlib')
32 files changed, 162 insertions, 135 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index a4d3c111356..5f8daa69ecf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -81,7 +82,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { } @Override - public ValueType getType(String name) { + public TensorType getType(String name) { Integer index = nameToIndex().get(name); if (index == null) return null; return values[index].type(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index a1e79df95e3..861f9565d66 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -25,21 +24,10 @@ public abstract class Context implements EvaluationContext { */ public abstract Value get(String name); - /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */ - @Override - public TensorType getTensorType(String name) { - ValueType type = getType(name); - if (type == null) return null; - return type.tensorType(); - } - /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } - /** Returns the type of the value of the given variable, or null if there is no such variable */ - public abstract ValueType getType(String name); - /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index c85a8f1c7e1..3ac11cff0cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -13,7 +14,7 @@ import com.yahoo.tensor.Tensor; public abstract class DoubleCompatibleValue extends Value { @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } @Override public boolean hasDouble() { return true; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index 34cd75df9cb..0625e8506cc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; /** * A variant of an array context variant which supports faster binding of variables but slower lookup @@ -67,7 +68,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } @Override - public ValueType getType(String name) { return ValueType.doubleType(); } + public TensorType getType(String name) { return TensorType.empty; } /** Perform a slow lookup by name */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2672fe6cd8e..39efe641f26 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.tensor.TensorType; + import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -13,7 +15,7 @@ import java.util.Set; */ public class MapContext extends Context { - private Map<String,Value> bindings=new HashMap<>(); + private Map<String, Value> bindings = new HashMap<>(); private boolean frozen = false; @@ -21,16 +23,6 @@ public class MapContext extends Context { } /** - * Freezes this. - * Returns this for convenience. - */ - public MapContext freeze() { - if ( ! frozen) - bindings = Collections.unmodifiableMap(bindings); - return this; - } - - /** * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. @@ -41,27 +33,32 @@ public class MapContext extends Context { boundValue.freeze(); } + /** + * Freezes this. + * Returns this for convenience. + */ + public MapContext freeze() { + if ( ! frozen) + bindings = Collections.unmodifiableMap(bindings); + return this; + } + /** Returns the type of the given value key, or null if it is not bound. */ @Override - public ValueType getType(String key) { + public TensorType getType(String key) { Value value = bindings.get(key); if (value == null) return null; return value.type(); } - /** - * Returns the value of a key. 0 is returned if the given key is not bound in this. - */ + /** Returns the value of a key. 0 is returned if the given key is not bound in this. */ @Override public Value get(String key) { return bindings.getOrDefault(key, DoubleValue.zero); } /** - * Sets the value of a key. - * The value is frozen by this. - * - * @since 5.1.5 + * Sets the value of a key.The value is frozen by this. */ @Override public void put(String key,Value value) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 874b41ec3e1..c60507310f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -5,6 +5,7 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * A string value. @@ -29,7 +30,7 @@ public class StringValue extends Value { } @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } /** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index f1c65dc79d3..c6e456f285d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A Value containing a tensor. @@ -25,7 +26,7 @@ public class TensorValue extends Value { } @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } @Override public double asDouble() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java new file mode 100644 index 00000000000..f2c4ca58f6d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java @@ -0,0 +1,27 @@ +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.HashMap; +import java.util.Map; + +/** + * A context which only contains type information. + * + * @author bratseth + */ +public class TypeMapContext implements TypeContext { + + private final Map<String, TensorType> featureTypes = new HashMap<>(); + + public void setType(String name, TensorType type) { + featureTypes.put(name, type); + } + + @Override + public TensorType getType(String name) { + return featureTypes.get(name); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 856bfb3638d..59d2d95b879 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -13,14 +13,14 @@ import com.yahoo.tensor.TensorType; * Concrete subclasses of this provides implementations of these methods or throws * UnsupportedOperationException if the operation is not supported. * - * @author bratseth + * @author bratseth */ public abstract class Value { private boolean frozen=false; /** Returns the type of this value */ - public abstract ValueType type(); + public abstract TensorType type(); /** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */ public abstract double asDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java deleted file mode 100644 index 046ad7861ef..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import com.yahoo.tensor.TensorType; - -/** - * The type of a ranking expression value - either a double or a tensor. - * - * @author bratseth - */ -public class ValueType { - - private static final ValueType doubleValueType = new ValueType(TensorType.empty); - - private final TensorType tensorType; - - private ValueType(TensorType tensorType) { - this.tensorType = tensorType; - } - - /** Returns true if this is the double type */ - public boolean isDouble() { return tensorType.rank() == 0; } - - /** The type of this as a tensor type. The double type is the empty tensor type (rank 0) */ - public TensorType tensorType() { return tensorType; } - - /** Returns the type representing a double */ - public static ValueType doubleType() { return doubleValueType; } - - /** Returns a type representing the given tensor type */ - public static ValueType of(TensorType type) { return new ValueType(type); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index b4e126f69e0..8ee4cdbf297 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -25,7 +26,7 @@ public class GBDTForestNode extends ExpressionNode { } @Override - public final ValueType type(Context context) { return ValueType.doubleType(); } + public final TensorType type(TypeContext context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index f085194a7df..aac635b2545 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -50,7 +51,7 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override - public final ValueType type(Context context) { return ValueType.doubleType(); } + public final TensorType type(TypeContext context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 816ef38e128..cdcb4df0360 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -147,14 +147,12 @@ class OperationMapper { return operation.get().map(params); } params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + - "' in node '" + params.node().getName() + "' is not supported."); + "' in node '" + params.node().getName() + "' is not supported."); return Optional.empty(); } - /* - * Operations - */ + // Operations --------------------------------- private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); @@ -209,10 +207,11 @@ class OperationMapper { TensorType type = params.result().arguments().get(name); if (type == null) { throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); + "', but there is no such placeholder"); } + params.result().requiredMacro(name, type); // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name)); + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type)); return Optional.of(output); } @@ -227,7 +226,7 @@ class OperationMapper { } private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { + if ( ! checkInputs(params, 2)) { return Optional.empty(); } List<Optional<TypedTensorFunction>> inputs = params.inputs(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 3a6b3f23a1d..6d78b501fdc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -14,7 +14,6 @@ import org.tensorflow.framework.TensorInfo; import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,10 +108,10 @@ public class TensorFlowImporter { } Optional<TypedTensorFunction> function = OperationMapper.map(params); - if (!function.isPresent()) { + if ( ! function.isPresent()) { return Optional.empty(); } - if (!controlDependenciesArePresent(params)) { + if ( ! controlDependenciesArePresent(params)) { return Optional.empty(); } params.imported().put(nodeName, function.get()); @@ -185,6 +184,7 @@ public class TensorFlowImporter { /** Parameter object to hold important data while importing */ static final class Parameters { + private final TensorFlowImporter owner; private final GraphDef graph; private final SavedModelBundle model; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 60aaf8ddce1..fe725e50a3f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -27,11 +27,13 @@ public class TensorFlowModel { private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +53,12 @@ public class TensorFlowModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** - * Returns an immutable map of expressions that can be overridden - such as PlaceholderWithDefault/ - */ + /** Returns an immutable map of macros that are part of this model */ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + /** Returns an immutable map of the macros that must be provided by the environment running this model */ + public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index d45037b6044..fc6428a4c33 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; import java.util.ArrayDeque; @@ -80,14 +80,14 @@ public final class ArithmeticNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { // Compute type using tensor types as arithmetic operators are supported on tensors // and is correct also in the special case of doubles. // As all our functions are type-commutative, we don't need to take operator precedence into account - TensorType type = children.get(0).type(context).tensorType(); + TensorType type = children.get(0).type(context); for (int i = 1; i < children.size(); i++) - type = Join.outputType(type, children.get(i).type(context).tensorType()); - return ValueType.of(type); + type = Join.outputType(type, children.get(i).type(context)); + return type; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index fdbb22093ea..7601c0e6180 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Deque; @@ -48,8 +49,8 @@ public class ComparisonNode extends BooleanNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); // by definition + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index e6074a5f745..1ea8d03f0eb 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -48,7 +49,7 @@ public final class ConstantNode extends ExpressionNode { } @Override - public ValueType type(Context context) { return value.type(); } + public TensorType type(TypeContext context) { return value.type(); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index 8404226c33b..fd9fab99db8 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -49,7 +50,7 @@ public final class EmbracedNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 5d06a562b5d..477f4db4981 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.Serializable; import java.util.Deque; @@ -47,7 +48,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public abstract ValueType type(Context context); + public abstract TensorType type(TypeContext context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index b187b8f029c..79515229019 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -4,7 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; import java.util.ArrayList; @@ -66,16 +67,16 @@ public final class FunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { if (arguments.expressions().size() == 0) - return ValueType.doubleType(); + return TensorType.empty; - ValueType argument1Type = arguments.expressions().get(0).type(context); + TensorType argument1Type = arguments.expressions().get(0).type(context); if (arguments.expressions().size() == 1) return argument1Type; - ValueType argument2Type = arguments.expressions().get(1).type(context); - return ValueType.of(Join.outputType(argument1Type.tensorType(), argument2Type.tensorType())); + TensorType argument2Type = arguments.expressions().get(1).type(context); + return Join.outputType(argument1Type, argument2Type); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index fcd40bed4d0..e42884ecc05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -48,7 +48,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { return ValueType.of(type); } + public TensorType type(TypeContext context) { return type; } /** Evaluate this in a context which must have the arguments bound */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index b9866bec027..076df327044 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -3,9 +3,13 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; /** * A conditional branch of a ranking expression. @@ -71,9 +75,9 @@ public final class IfNode extends CompositeNode { } @Override - public ValueType type(Context context) { - ValueType trueType = trueExpression.type(context); - ValueType falseType = falseExpression.type(context); + public TensorType type(TypeContext context) { + TensorType trueType = trueExpression.type(context); + TensorType falseType = falseExpression.type(context); if ( ! trueType.equals(falseType)) throw new IllegalArgumentException("An if expression must produce a value of the same type in both " + "alternatives, but the 'true' type is " + trueType + " while the " + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index b898529c4b9..da946228291 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -5,7 +5,8 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -56,8 +57,8 @@ public class LambdaFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); // by definition - no nested lambdas + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition - no nested lambdas } /** Evaluate this in a context which must have the arguments bound */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index cf6475238c4..f55ed59b65c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -31,7 +32,7 @@ public final class NameNode extends ExpressionNode { } @Override - public ValueType type(Context context) { throw new RuntimeException("Named nodes can not have a type"); } + public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 2e685a6c8ab..9cbe5f98c72 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -37,7 +38,7 @@ public class NegativeNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index c4b940f1bd6..e7041600635 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -37,7 +38,7 @@ public class NotNode extends BooleanNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index e5176f9966d..f79297f7773 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,7 +5,8 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayDeque; import java.util.Deque; @@ -45,9 +46,8 @@ public final class ReferenceNode extends CompositeNode { return new ReferenceNode(name, arguments, output); } - public String getOutput() { - return output; - } + /** Returns the specific output this references, or null if none specified */ + public String getOutput() { return output; } /** Returns a copy of this node with a modified output */ public ReferenceNode setOutput(String output) { @@ -106,7 +106,7 @@ public final class ReferenceNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { // Don't support outputs of different type, for simplicity return context.getType(name); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index a8b82c560f7..a7b82f4753f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -6,8 +6,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Deque; @@ -59,8 +60,8 @@ public class SetMembershipNode extends BooleanNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); + public TensorType type(TypeContext context) { + return TensorType.empty; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 97cfa2a5350..e4c381972e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -5,10 +5,10 @@ import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -64,7 +64,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { return ValueType.of(function.type(context)); } + public TensorType type(TypeContext context) { return function.type(context); } @Override public Value evaluate(Context context) { @@ -111,8 +111,8 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { - return expression.type((Context)context).tensorType(); + public TensorType type(TypeContext context) { + return expression.type(context); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index b59b4750911..445ccf231a7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -2,10 +2,12 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -15,6 +17,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + + // Check (provided) macros + assertEquals(1, model.get().macros().size()); + assertTrue(model.get().macros().containsKey("training/input")); + assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("X")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("X")); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index f12b9a2c628..01dd15d5fa0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -8,6 +8,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author bratseth @@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); + // Check (provided) macros + assertEquals(0, model.get().macros().size()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("Placeholder")); + // Check signatures assertEquals(1, model.get().signatures().size()); TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); |