diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
commit | 1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch) | |
tree | 20a127542b004eceb94e4d1344b3446df8092bd2 | |
parent | 28e3545728977a0be82159b8f278be8e772cb59b (diff) |
Propagate type information
33 files changed, 273 insertions, 120 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 23dd841b0ef..a4d3c111356 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -48,12 +48,11 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { Integer index = nameToIndex().get(name); - if (index==null) { + if (index == null) { if (ignoreUnknownValues()) return; else @@ -70,24 +69,29 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** * Puts a value by index. * The value will be frozen if it isn't already. - * - * @since 5.1.5 */ public final void put(int index, Value value) { - values[index]=value.freeze(); + values[index] = value.freeze(); try { - doubleValues()[index]=value.asDouble(); + doubleValues()[index] = value.asDouble(); } catch (UnsupportedOperationException e) { - doubleValues()[index]=Double.NaN; // see getDouble below + doubleValues()[index] = Double.NaN; // see getDouble below } } + @Override + public ValueType getType(String name) { + Integer index = nameToIndex().get(name); + if (index == null) return null; + return values[index].type(); + } + /** Perform a slow lookup by name */ @Override public Value get(String name) { - Integer index=nameToIndex().get(name); - if (index==null) return DoubleValue.zero; + Integer index = nameToIndex().get(name); + if (index == null) return DoubleValue.zero; return values[index]; } @@ -100,8 +104,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */ @Override public final double getDouble(int index) { - double value=doubleValues()[index]; - if (value==Double.NaN) + double value = doubleValues()[index]; + if (value == Double.NaN) throw new UnsupportedOperationException("Value at " + index + " has no double representation"); return value; } @@ -111,7 +115,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * in a different thread (i.e, name name to index map, different value set. */ public ArrayContext clone() { - ArrayContext clone=(ArrayContext)super.clone(); + ArrayContext clone = (ArrayContext)super.clone(); clone.values = new Value[nameToIndex().size()]; Arrays.fill(values,constantZero); return clone; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0eeb0a9e630..ff8758bd1e7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -17,17 +18,29 @@ import java.util.stream.Collectors; public abstract class Context implements EvaluationContext { /** - * <p>Returns the value of a simple variable name.</p> + * Returns the value of a simple variable name. * * @param name the name of the variable whose value to return. * @return the value of the named variable. */ public abstract Value get(String name); + /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */ + @Override + public TensorType getTensorType(String name) { + ValueType type = getType(name); + if (type == null) return null; + if (type.isTensor()) return type.tensorType().get(); + return TensorType.empty; // double as tensor + } + /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } + /** Returns the type of the value of the given variable, or null if there is no such variable */ + public abstract ValueType getType(String name); + /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2ef4a2ede2f..c85a8f1c7e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -4,18 +4,18 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. * * @author bratseth - * @since 5.1.21 */ public abstract class DoubleCompatibleValue extends Value { @Override + public ValueType type() { return ValueType.doubleType(); } + + @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index ceec9358b3c..34cd75df9cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -38,7 +38,6 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { @@ -57,11 +56,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { doubleValues()[index] = value; } - /** - * Puts a value by index. - * - * @since 5.1.5 - */ + /** Puts a value by index. */ public final void put(int index, Value value) { try { put(index, value.asDouble()); @@ -71,6 +66,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } } + @Override + public ValueType getType(String name) { return ValueType.doubleType(); } + /** Perform a slow lookup by name */ @Override public Value get(String name) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 09895a0c2f6..2672fe6cd8e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -41,14 +41,20 @@ public class MapContext extends Context { boundValue.freeze(); } + /** Returns the type of the given value key, or null if it is not bound. */ + @Override + public ValueType getType(String key) { + Value value = bindings.get(key); + if (value == null) return null; + return value.type(); + } + /** * Returns the value of a key. 0 is returned if the given key is not bound in this. */ @Override public Value get(String key) { - Value value = bindings.get(key); - if (value == null) return DoubleValue.zero; - return value; + return bindings.getOrDefault(key, DoubleValue.zero); } /** @@ -67,7 +73,7 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } - + /** Returns a new, modifiable context containing all the bindings of this */ public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index dad69b31181..874b41ec3e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -5,8 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; /** * A string value. @@ -30,6 +28,9 @@ public class StringValue extends Value { this.value = value; } + @Override + public ValueType type() { return ValueType.doubleType(); } + /** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */ @Override public double asDouble() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 26c30fe5ed2..f1c65dc79d3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -25,6 +25,9 @@ public class TensorValue extends Value { } @Override + public ValueType type() { return ValueType.doubleType(); } + + @Override public double asDouble() { if (hasDouble()) return value.get(TensorAddress.of()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 40d70e0022c..856bfb3638d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -19,6 +19,9 @@ public abstract class Value { private boolean frozen=false; + /** Returns the type of this value */ + public abstract ValueType type(); + /** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */ public abstract double asDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java new file mode 100644 index 00000000000..06301372dcc --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java @@ -0,0 +1,37 @@ +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; + +import java.util.Optional; + +/** + * The type of a ranking expression value - either a double or a tensor. + * + * @author bratseth + */ +public class ValueType { + + private static final ValueType doubleValueType = new ValueType(Optional.empty()); + + private final Optional<TensorType> tensorType; + + private ValueType(Optional<TensorType> type) { + this.tensorType = type; + } + + /** Returns true if this is a double type */ + public boolean isDouble() { return ! tensorType.isPresent(); } + + /** Returns true if this is a tensor type */ + public boolean isTensor() { return tensorType.isPresent(); } + + /** The specific tensor type of this, or empty if this is not a tensor type */ + public Optional<TensorType> tensorType() { return tensorType; } + + /** Returns the type representing a double */ + public static ValueType doubleType() { return doubleValueType; } + + /** Returns a type representing the given tensor type */ + public static ValueType tensorType(TensorType type) { return new ValueType(Optional.of(type)); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index eb303fc6446..31984dca54d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import java.io.Serializable; import java.util.Deque; @@ -41,6 +42,14 @@ public abstract class ExpressionNode implements Serializable { public abstract String toString(SerializationContext context, Deque<String> path, CompositeNode parent); /** + * Returns the type this will return if evaluated with the given context. + * + * @param context the variable type bindings to use for this evaluation + * @throws IllegalArgumentException if there are variables which are not bound in the given map + */ + public ValueType type(Context context) { return ValueType.doubleType(); } // double is default + + /** * Returns the value of evaluating this expression over the given context. * * @param context the variable bindings to use for this evaluation diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index b42570d3aea..c85a15ada64 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -5,7 +5,9 @@ import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; @@ -35,7 +37,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public List<ExpressionNode> children() { - return function.functionArguments().stream() + return function.arguments().stream() .map(this::toExpressionNode) .collect(Collectors.toList()); } @@ -52,7 +54,7 @@ public class TensorFunctionNode extends CompositeNode { List<TensorFunction> wrappedChildren = children.stream() .map(TensorFunctionExpressionNode::new) .collect(Collectors.toList()); - return new TensorFunctionNode(function.replaceArguments(wrappedChildren)); + return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @Override @@ -62,6 +64,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override + public ValueType type(Context context) { return ValueType.tensorType(function.type(context)); } + + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); } @@ -84,7 +89,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -94,7 +99,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((TensorFunctionExpressionNode)arg).expression) @@ -106,12 +111,17 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { + return expression.type((Context)context).tensorType().orElse(TensorType.empty); + } + + @Override public Tensor evaluate(EvaluationContext context) { Value result = expression.evaluate((Context)context); if ( ! ( result instanceof TensorValue)) throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + "but this returns " + result + ", not a tensor"); - return ((TensorValue)result).asTensor(); + return result.asTensor(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3db661f8a23..e18b77a0434 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * An evaluation context which is passed down to all nested functions during evaluation. @@ -12,6 +13,14 @@ import com.yahoo.tensor.Tensor; @Beta public interface EvaluationContext { + /** + * Returns tye type of the tensor with this name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getTensorType(String name); + /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index db8a66a5fa2..6bdfe8f19b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.HashMap; @@ -19,6 +20,13 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override + public TensorType getTensorType(String name) { + Tensor tensor = bindings.get(name); + if (tensor == null) return null; + return tensor.type(); + } + + @Override public Tensor getTensor(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 1f6ad050368..6c149724aca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -25,15 +26,18 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { return this; } + public TensorFunction withArguments(List<TensorFunction> arguments) { return this; } @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return context.getTensorType(name); } + + @Override public Tensor evaluate(EvaluationContext context) { return context.getTensor(name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 10f53670826..93365d20966 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -14,17 +14,17 @@ public class Argmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); return new Argmax(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmax extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index d324aec53e9..e598cdf8a98 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -14,17 +14,17 @@ public class Argmin extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmin(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); return new Argmin(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmin extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmin(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 191c7988443..0c43caef05c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; /** @@ -14,6 +15,10 @@ import com.yahoo.tensor.evaluation.EvaluationContext; @Beta public abstract class CompositeTensorFunction extends TensorFunction { + /** Finds the type this produces by first converting it to a primitive function */ + @Override + public final TensorType type(EvaluationContext context) { return toPrimitive().type(context); } + /** Evaluates this by first converting it to a primitive function */ @Override public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index d4affe0ef9b..cc8067224c7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -34,10 +34,10 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); return new Concat(arguments.get(0), arguments.get(1), dimension); @@ -54,6 +54,20 @@ public class Concat extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argumentA.type(context), argumentB.type(context)); + } + + /** Returns the type resulting from concatenating a and b */ + private TensorType type(TensorType a, TensorType b) { + TensorType.Builder builder = new TensorType.Builder(a, b); + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); @@ -63,7 +77,7 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - TensorType concatType = concatType(a, b); + TensorType concatType = type(a.type(), b.type()); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); @@ -115,15 +129,6 @@ public class Concat extends PrimitiveTensorFunction { } - /** Returns the type resulting from concatenating a and b */ - private TensorType concatType(Tensor a, Tensor b) { - TensorType.Builder builder = new TensorType.Builder(a.type(), b.type()); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() + - b.type().dimension(dimension).get().size().get())); - return builder.build(); - } - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 14ed38718ce..4a6d656142f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -27,10 +28,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); return this; @@ -40,6 +41,9 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return constant.type(); } + + @Override public Tensor evaluate(EvaluationContext context) { return constant; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 653be8dacf0..e302f6606e7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -25,10 +25,10 @@ public class Diag extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index ef2770c04f5..ff9589bd6ae 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -47,10 +47,10 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; @@ -60,6 +60,9 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return type; } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 174a8e4c435..01c681bfb36 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -70,15 +70,13 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - public TensorFunction argumentA() { return argumentA; } - public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); return new Join(arguments.get(0), arguments.get(1), combinator); @@ -95,6 +93,11 @@ public class Join extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 91a9c6d1b27..d7f7ae59d62 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -14,17 +14,17 @@ public class L1Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L1Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); return new L1Normalize(arguments.get(0), dimension); @@ -38,7 +38,7 @@ public class L1Normalize extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index bdf8921f81d..e2c526760bd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -14,17 +14,17 @@ public class L2Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L2Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); return new L2Normalize(arguments.get(0), dimension); @@ -40,7 +40,7 @@ public class L2Normalize extends CompositeTensorFunction { ScalarFunctions.sqrt()), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a5e1a016a41..e5440b56c54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -2,8 +2,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -39,10 +37,10 @@ public class Map extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); return new Map(arguments.get(0), mapper); @@ -54,6 +52,11 @@ public class Map extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return argument.type(context); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4071917c2b5..935e4761cfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -27,10 +27,10 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); return new Matmul(arguments.get(0), arguments.get(1), dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 958ef85d1dc..1475f7f4ac1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -28,10 +28,10 @@ public class Random extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 8e7f4e4c773..d951ec9ccbd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -26,10 +26,10 @@ public class Range extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..591a6e4649e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -73,10 +73,10 @@ public class Reduce extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); return new Reduce(arguments.get(0), aggregator, dimensions); @@ -100,6 +100,19 @@ public class Reduce extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType argumentType) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : argumentType.dimensions()) + if ( ! dimensions.contains(dimension.name())) // keep + builder.dimension(dimension); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) @@ -113,12 +126,7 @@ public class Reduce extends PrimitiveTensorFunction { else return reduceAllGeneral(argument); - // Reduce type - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : argument.type().dimensions()) - if ( ! dimensions.contains(dimension.name())) // keep - builder.dimension(dimension); - TensorType reducedType = builder.build(); + TensorType reducedType = type(argument.type()); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..6a9b8d68b38 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -26,6 +26,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + private final Map<String, String> fromToMap; public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); @@ -43,13 +44,24 @@ public class Rename extends PrimitiveTensorFunction { this.argument = argument; this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); + this.fromToMap = fromToMap(fromDimensions, toDimensions); + } + + public List<String> fromDimensions() { return fromDimensions; } + public List<String> toDimensions() { return toDimensions; } + + private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) { + Map<String, String> map = new HashMap<>(); + for (int i = 0; i < fromDimensions.size(); i++) + map.put(fromDimensions.get(i), toDimensions.get(i)); + return map; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); return new Rename(arguments.get(0), fromDimensions, toDimensions); @@ -59,11 +71,22 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) + builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - Map<String, String> fromToMap = fromToMap(); - TensorType renamedType = rename(tensor.type(), fromToMap); + TensorType renamedType = type(tensor.type()); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -82,13 +105,6 @@ public class Rename extends PrimitiveTensorFunction { return builder.build(); } - private TensorType rename(TensorType type, Map<String, String> fromToMap) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : type.dimensions()) - builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); - return builder.build(); - } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -102,13 +118,6 @@ public class Rename extends PrimitiveTensorFunction { toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - private Map<String, String> fromToMap() { - Map<String, String> map = new HashMap<>(); - for (int i = 0; i < fromDimensions.size(); i++) - map.put(fromDimensions.get(i), toDimensions.get(i)); - return map; - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index c856b548180..32cff5ac84a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -16,21 +16,21 @@ public class Softmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Softmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } - + public static TensorType outputType(TensorType inputType, String dimension) { return Reduce.outputType(inputType, ImmutableList.of(dimension)); } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); return new Softmax(arguments.get(0), dimension); @@ -45,7 +45,7 @@ public class Softmax extends CompositeTensorFunction { dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "softmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 533a46f87fe..3f6dfae6222 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.MapEvaluationContext; @@ -19,14 +20,14 @@ import java.util.List; public abstract class TensorFunction { /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> functionArguments(); + public abstract List<TensorFunction> arguments(); /** * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. * * @throws IllegalArgumentException if the argument list has the wrong size for this function */ - public abstract TensorFunction replaceArguments(List<TensorFunction> arguments); + public abstract TensorFunction withArguments(List<TensorFunction> arguments); /** * Translate this function - and all of its arguments recursively - @@ -43,6 +44,13 @@ public abstract class TensorFunction { */ public abstract Tensor evaluate(EvaluationContext context); + /** + * Returns the type of the tensor this produces given the input types in the context + * + * @param context a context which must be passed to all nexted functions when evaluating + */ + public abstract TensorType type(EvaluationContext context); + /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 2464be981f5..78ff0731566 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -14,7 +14,7 @@ public class XwPlusB extends CompositeTensorFunction { private final TensorFunction x, w, b; private final String dimension; - + public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { this.x = x; this.w = w; @@ -23,10 +23,10 @@ public class XwPlusB extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } + public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 3) throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); @@ -43,7 +43,7 @@ public class XwPlusB extends CompositeTensorFunction { primitiveB, ScalarFunctions.add()); } - + @Override public String toString(ToStringContext context) { return "xw_plus_b(" + x.toString(context) + ", " + |