aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-29 14:51:23 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-29 14:51:23 +0100
commit1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch)
tree20a127542b004eceb94e4d1344b3446df8092bd2
parent28e3545728977a0be82159b8f278be8e772cb59b (diff)
Propagate type information
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java37
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java45
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java8
33 files changed, 273 insertions, 120 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index 23dd841b0ef..a4d3c111356 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -48,12 +48,11 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
*
* @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
* ignoredUnknownValues is false
- * @since 5.1.5
*/
@Override
public final void put(String name, Value value) {
Integer index = nameToIndex().get(name);
- if (index==null) {
+ if (index == null) {
if (ignoreUnknownValues())
return;
else
@@ -70,24 +69,29 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
/**
* Puts a value by index.
* The value will be frozen if it isn't already.
- *
- * @since 5.1.5
*/
public final void put(int index, Value value) {
- values[index]=value.freeze();
+ values[index] = value.freeze();
try {
- doubleValues()[index]=value.asDouble();
+ doubleValues()[index] = value.asDouble();
}
catch (UnsupportedOperationException e) {
- doubleValues()[index]=Double.NaN; // see getDouble below
+ doubleValues()[index] = Double.NaN; // see getDouble below
}
}
+ @Override
+ public ValueType getType(String name) {
+ Integer index = nameToIndex().get(name);
+ if (index == null) return null;
+ return values[index].type();
+ }
+
/** Perform a slow lookup by name */
@Override
public Value get(String name) {
- Integer index=nameToIndex().get(name);
- if (index==null) return DoubleValue.zero;
+ Integer index = nameToIndex().get(name);
+ if (index == null) return DoubleValue.zero;
return values[index];
}
@@ -100,8 +104,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
/** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */
@Override
public final double getDouble(int index) {
- double value=doubleValues()[index];
- if (value==Double.NaN)
+ double value = doubleValues()[index];
+ if (value == Double.NaN)
throw new UnsupportedOperationException("Value at " + index + " has no double representation");
return value;
}
@@ -111,7 +115,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
* in a different thread (i.e, name name to index map, different value set.
*/
public ArrayContext clone() {
- ArrayContext clone=(ArrayContext)super.clone();
+ ArrayContext clone = (ArrayContext)super.clone();
clone.values = new Value[nameToIndex().size()];
Arrays.fill(values,constantZero);
return clone;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 0eeb0a9e630..ff8758bd1e7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -17,17 +18,29 @@ import java.util.stream.Collectors;
public abstract class Context implements EvaluationContext {
/**
- * <p>Returns the value of a simple variable name.</p>
+ * Returns the value of a simple variable name.
*
* @param name the name of the variable whose value to return.
* @return the value of the named variable.
*/
public abstract Value get(String name);
+ /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */
+ @Override
+ public TensorType getTensorType(String name) {
+ ValueType type = getType(name);
+ if (type == null) return null;
+ if (type.isTensor()) return type.tensorType().get();
+ return TensorType.empty; // double as tensor
+ }
+
/** Returns a variable as a tensor */
@Override
public Tensor getTensor(String name) { return get(name).asTensor(); }
+ /** Returns the type of the value of the given variable, or null if there is no such variable */
+ public abstract ValueType getType(String name);
+
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2ef4a2ede2f..c85a8f1c7e1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -4,18 +4,18 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
*
* @author bratseth
- * @since 5.1.21
*/
public abstract class DoubleCompatibleValue extends Value {
@Override
+ public ValueType type() { return ValueType.doubleType(); }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index ceec9358b3c..34cd75df9cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -38,7 +38,6 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
*
* @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
* ignoredUnknownValues is false
- * @since 5.1.5
*/
@Override
public final void put(String name, Value value) {
@@ -57,11 +56,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
doubleValues()[index] = value;
}
- /**
- * Puts a value by index.
- *
- * @since 5.1.5
- */
+ /** Puts a value by index. */
public final void put(int index, Value value) {
try {
put(index, value.asDouble());
@@ -71,6 +66,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
}
}
+ @Override
+ public ValueType getType(String name) { return ValueType.doubleType(); }
+
/** Perform a slow lookup by name */
@Override
public Value get(String name) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 09895a0c2f6..2672fe6cd8e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -41,14 +41,20 @@ public class MapContext extends Context {
boundValue.freeze();
}
+ /** Returns the type of the given value key, or null if it is not bound. */
+ @Override
+ public ValueType getType(String key) {
+ Value value = bindings.get(key);
+ if (value == null) return null;
+ return value.type();
+ }
+
/**
* Returns the value of a key. 0 is returned if the given key is not bound in this.
*/
@Override
public Value get(String key) {
- Value value = bindings.get(key);
- if (value == null) return DoubleValue.zero;
- return value;
+ return bindings.getOrDefault(key, DoubleValue.zero);
}
/**
@@ -67,7 +73,7 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
-
+
/** Returns a new, modifiable context containing all the bindings of this */
public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index dad69b31181..874b41ec3e1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -5,8 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
/**
* A string value.
@@ -30,6 +28,9 @@ public class StringValue extends Value {
this.value = value;
}
+ @Override
+ public ValueType type() { return ValueType.doubleType(); }
+
/** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */
@Override
public double asDouble() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 26c30fe5ed2..f1c65dc79d3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -25,6 +25,9 @@ public class TensorValue extends Value {
}
@Override
+ public ValueType type() { return ValueType.doubleType(); }
+
+ @Override
public double asDouble() {
if (hasDouble())
return value.get(TensorAddress.of());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 40d70e0022c..856bfb3638d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -19,6 +19,9 @@ public abstract class Value {
private boolean frozen=false;
+ /** Returns the type of this value */
+ public abstract ValueType type();
+
/** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */
public abstract double asDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
new file mode 100644
index 00000000000..06301372dcc
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java
@@ -0,0 +1,37 @@
+package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import com.yahoo.tensor.TensorType;
+
+import java.util.Optional;
+
+/**
+ * The type of a ranking expression value - either a double or a tensor.
+ *
+ * @author bratseth
+ */
+public class ValueType {
+
+ private static final ValueType doubleValueType = new ValueType(Optional.empty());
+
+ private final Optional<TensorType> tensorType;
+
+ private ValueType(Optional<TensorType> type) {
+ this.tensorType = type;
+ }
+
+ /** Returns true if this is a double type */
+ public boolean isDouble() { return ! tensorType.isPresent(); }
+
+ /** Returns true if this is a tensor type */
+ public boolean isTensor() { return tensorType.isPresent(); }
+
+ /** The specific tensor type of this, or empty if this is not a tensor type */
+ public Optional<TensorType> tensorType() { return tensorType; }
+
+ /** Returns the type representing a double */
+ public static ValueType doubleType() { return doubleValueType; }
+
+ /** Returns a type representing the given tensor type */
+ public static ValueType tensorType(TensorType type) { return new ValueType(Optional.of(type)); }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index eb303fc6446..31984dca54d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import java.io.Serializable;
import java.util.Deque;
@@ -41,6 +42,14 @@ public abstract class ExpressionNode implements Serializable {
public abstract String toString(SerializationContext context, Deque<String> path, CompositeNode parent);
/**
+ * Returns the type this will return if evaluated with the given context.
+ *
+ * @param context the variable type bindings to use for this evaluation
+ * @throws IllegalArgumentException if there are variables which are not bound in the given map
+ */
+ public ValueType type(Context context) { return ValueType.doubleType(); } // double is default
+
+ /**
* Returns the value of evaluating this expression over the given context.
*
* @param context the variable bindings to use for this evaluation
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index b42570d3aea..c85a15ada64 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -5,7 +5,9 @@ import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.evaluation.ValueType;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
@@ -35,7 +37,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public List<ExpressionNode> children() {
- return function.functionArguments().stream()
+ return function.arguments().stream()
.map(this::toExpressionNode)
.collect(Collectors.toList());
}
@@ -52,7 +54,7 @@ public class TensorFunctionNode extends CompositeNode {
List<TensorFunction> wrappedChildren = children.stream()
.map(TensorFunctionExpressionNode::new)
.collect(Collectors.toList());
- return new TensorFunctionNode(function.replaceArguments(wrappedChildren));
+ return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@Override
@@ -62,6 +64,9 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public ValueType type(Context context) { return ValueType.tensorType(function.type(context)); }
+
+ @Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
}
@@ -84,7 +89,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -94,7 +99,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
.map(arg -> ((TensorFunctionExpressionNode)arg).expression)
@@ -106,12 +111,17 @@ public class TensorFunctionNode extends CompositeNode {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(EvaluationContext context) {
+ return expression.type((Context)context).tensorType().orElse(TensorType.empty);
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Value result = expression.evaluate((Context)context);
if ( ! ( result instanceof TensorValue))
throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
"but this returns " + result + ", not a tensor");
- return ((TensorValue)result).asTensor();
+ return result.asTensor();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 3db661f8a23..e18b77a0434 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* An evaluation context which is passed down to all nested functions during evaluation.
@@ -12,6 +13,14 @@ import com.yahoo.tensor.Tensor;
@Beta
public interface EvaluationContext {
+ /**
+ * Returns tye type of the tensor with this name.
+ *
+ * @return returns the type of the tensor which will be returned by calling getTensor(name)
+ * or null if getTensor will return null.
+ */
+ TensorType getTensorType(String name);
+
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index db8a66a5fa2..6bdfe8f19b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import java.util.HashMap;
@@ -19,6 +20,13 @@ public class MapEvaluationContext implements EvaluationContext {
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
@Override
+ public TensorType getTensorType(String name) {
+ Tensor tensor = bindings.get(name);
+ if (tensor == null) return null;
+ return tensor.type();
+ }
+
+ @Override
public Tensor getTensor(String name) { return bindings.get(name); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 1f6ad050368..6c149724aca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -25,15 +26,18 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) { return this; }
+ public TensorFunction withArguments(List<TensorFunction> arguments) { return this; }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(EvaluationContext context) { return context.getTensorType(name); }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
return context.getTensor(name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 10f53670826..93365d20966 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -14,17 +14,17 @@ public class Argmax extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Argmax(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size());
return new Argmax(arguments.get(0), dimension);
@@ -37,7 +37,7 @@ public class Argmax extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension),
ScalarFunctions.equal());
}
-
+
@Override
public String toString(ToStringContext context) {
return "argmax(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index d324aec53e9..e598cdf8a98 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -14,17 +14,17 @@ public class Argmin extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Argmin(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size());
return new Argmin(arguments.get(0), dimension);
@@ -37,7 +37,7 @@ public class Argmin extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension),
ScalarFunctions.equal());
}
-
+
@Override
public String toString(ToStringContext context) {
return "argmin(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 191c7988443..0c43caef05c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
/**
@@ -14,6 +15,10 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
@Beta
public abstract class CompositeTensorFunction extends TensorFunction {
+ /** Finds the type this produces by first converting it to a primitive function */
+ @Override
+ public final TensorType type(EvaluationContext context) { return toPrimitive().type(context); }
+
/** Evaluates this by first converting it to a primitive function */
@Override
public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index d4affe0ef9b..cc8067224c7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -34,10 +34,10 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
return new Concat(arguments.get(0), arguments.get(1), dimension);
@@ -54,6 +54,20 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return type(argumentA.type(context), argumentB.type(context));
+ }
+
+ /** Returns the type resulting from concatenating a and b */
+ private TensorType type(TensorType a, TensorType b) {
+ TensorType.Builder builder = new TensorType.Builder(a, b);
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
+ b.dimension(dimension).get().size().get()));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
@@ -63,7 +77,7 @@ public class Concat extends PrimitiveTensorFunction {
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
- TensorType concatType = concatType(a, b);
+ TensorType concatType = type(a.type(), b.type());
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
@@ -115,15 +129,6 @@ public class Concat extends PrimitiveTensorFunction {
}
- /** Returns the type resulting from concatenating a and b */
- private TensorType concatType(Tensor a, Tensor b) {
- TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() +
- b.type().dimension(dimension).get().size().get()));
- return builder.build();
- }
-
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 14ed38718ce..4a6d656142f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -27,10 +28,10 @@ public class ConstantTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size());
return this;
@@ -40,6 +41,9 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(EvaluationContext context) { return constant.type(); }
+
+ @Override
public Tensor evaluate(EvaluationContext context) { return constant; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 653be8dacf0..e302f6606e7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -25,10 +25,10 @@ public class Diag extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index ef2770c04f5..ff9589bd6ae 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -47,10 +47,10 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
return this;
@@ -60,6 +60,9 @@ public class Generate extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(EvaluationContext context) { return type; }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 174a8e4c435..01c681bfb36 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -70,15 +70,13 @@ public class Join extends PrimitiveTensorFunction {
return typeBuilder.build();
}
- public TensorFunction argumentA() { return argumentA; }
- public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
return new Join(arguments.get(0), arguments.get(1), combinator);
@@ -95,6 +93,11 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 91a9c6d1b27..d7f7ae59d62 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -14,17 +14,17 @@ public class L1Normalize extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public L1Normalize(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size());
return new L1Normalize(arguments.get(0), dimension);
@@ -38,7 +38,7 @@ public class L1Normalize extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index bdf8921f81d..e2c526760bd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -14,17 +14,17 @@ public class L2Normalize extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public L2Normalize(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size());
return new L2Normalize(arguments.get(0), dimension);
@@ -40,7 +40,7 @@ public class L2Normalize extends CompositeTensorFunction {
ScalarFunctions.sqrt()),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a5e1a016a41..e5440b56c54 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -2,8 +2,6 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -39,10 +37,10 @@ public class Map extends PrimitiveTensorFunction {
public DoubleUnaryOperator mapper() { return mapper; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
return new Map(arguments.get(0), mapper);
@@ -54,6 +52,11 @@ public class Map extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return argument.type(context);
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 4071917c2b5..935e4761cfe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -27,10 +27,10 @@ public class Matmul extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
return new Matmul(arguments.get(0), arguments.get(1), dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 958ef85d1dc..1475f7f4ac1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -28,10 +28,10 @@ public class Random extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 8e7f4e4c773..d951ec9ccbd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -26,10 +26,10 @@ public class Range extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index de9f90a5804..591a6e4649e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -73,10 +73,10 @@ public class Reduce extends PrimitiveTensorFunction {
public TensorFunction argument() { return argument; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
return new Reduce(arguments.get(0), aggregator, dimensions);
@@ -100,6 +100,19 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType argumentType) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : argumentType.dimensions())
+ if ( ! dimensions.contains(dimension.name())) // keep
+ builder.dimension(dimension);
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
@@ -113,12 +126,7 @@ public class Reduce extends PrimitiveTensorFunction {
else
return reduceAllGeneral(argument);
- // Reduce type
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : argument.type().dimensions())
- if ( ! dimensions.contains(dimension.name())) // keep
- builder.dimension(dimension);
- TensorType reducedType = builder.build();
+ TensorType reducedType = type(argument.type());
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ec9b762a41c..6a9b8d68b38 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -26,6 +26,7 @@ public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ private final Map<String, String> fromToMap;
public Rename(TensorFunction argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
@@ -43,13 +44,24 @@ public class Rename extends PrimitiveTensorFunction {
this.argument = argument;
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
+ this.fromToMap = fromToMap(fromDimensions, toDimensions);
+ }
+
+ public List<String> fromDimensions() { return fromDimensions; }
+ public List<String> toDimensions() { return toDimensions; }
+
+ private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) {
+ Map<String, String> map = new HashMap<>();
+ for (int i = 0; i < fromDimensions.size(); i++)
+ map.put(fromDimensions.get(i), toDimensions.get(i));
+ return map;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
return new Rename(arguments.get(0), fromDimensions, toDimensions);
@@ -59,11 +71,22 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(EvaluationContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : type.dimensions())
+ builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor tensor = argument.evaluate(context);
- Map<String, String> fromToMap = fromToMap();
- TensorType renamedType = rename(tensor.type(), fromToMap);
+ TensorType renamedType = type(tensor.type());
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
@@ -82,13 +105,6 @@ public class Rename extends PrimitiveTensorFunction {
return builder.build();
}
- private TensorType rename(TensorType type, Map<String, String> fromToMap) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : type.dimensions())
- builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
- return builder.build();
- }
-
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -102,13 +118,6 @@ public class Rename extends PrimitiveTensorFunction {
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
- private Map<String, String> fromToMap() {
- Map<String, String> map = new HashMap<>();
- for (int i = 0; i < fromDimensions.size(); i++)
- map.put(fromDimensions.get(i), toDimensions.get(i));
- return map;
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index c856b548180..32cff5ac84a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -16,21 +16,21 @@ public class Softmax extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Softmax(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
-
+
public static TensorType outputType(TensorType inputType, String dimension) {
return Reduce.outputType(inputType, ImmutableList.of(dimension));
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
return new Softmax(arguments.get(0), dimension);
@@ -45,7 +45,7 @@ public class Softmax extends CompositeTensorFunction {
dimension),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 533a46f87fe..3f6dfae6222 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
@@ -19,14 +20,14 @@ import java.util.List;
public abstract class TensorFunction {
/** Returns the function arguments of this node in the order they are applied */
- public abstract List<TensorFunction> functionArguments();
+ public abstract List<TensorFunction> arguments();
/**
* Returns a copy of this tensor function with the arguments replaced by the given list of arguments.
*
* @throws IllegalArgumentException if the argument list has the wrong size for this function
*/
- public abstract TensorFunction replaceArguments(List<TensorFunction> arguments);
+ public abstract TensorFunction withArguments(List<TensorFunction> arguments);
/**
* Translate this function - and all of its arguments recursively -
@@ -43,6 +44,13 @@ public abstract class TensorFunction {
*/
public abstract Tensor evaluate(EvaluationContext context);
+ /**
+ * Returns the type of the tensor this produces given the input types in the context
+ *
+ * @param context a context which must be passed to all nexted functions when evaluating
+ */
+ public abstract TensorType type(EvaluationContext context);
+
/** Evaluate with no context */
public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 2464be981f5..78ff0731566 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -14,7 +14,7 @@ public class XwPlusB extends CompositeTensorFunction {
private final TensorFunction x, w, b;
private final String dimension;
-
+
public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
this.x = x;
this.w = w;
@@ -23,10 +23,10 @@ public class XwPlusB extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 3)
throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size());
return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
@@ -43,7 +43,7 @@ public class XwPlusB extends CompositeTensorFunction {
primitiveB,
ScalarFunctions.add());
}
-
+
@Override
public String toString(ToStringContext context) {
return "xw_plus_b(" + x.toString(context) + ", " +