diff options
Diffstat (limited to 'vespajlib/src')
12 files changed, 84 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 0727579a331..153a3f896de 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -27,6 +27,13 @@ public class ConstantTensor extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index c0e5776bf48..013a95fe51f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -44,6 +44,13 @@ public class Generate extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 323da5906c3..ce1f123a216 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -44,6 +44,13 @@ public class Join extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); + return new Join(arguments.get(0), arguments.get(1), combinator); + } + + @Override public PrimitiveTensorFunction toPrimitive() { return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 4467b378b3f..2e61792aa90 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -20,6 +20,13 @@ public class L1Normalize extends CompositeTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); + return new L1Normalize(arguments.get(0), dimension); + } + + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y)) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 0e96b43bd22..40d1b2a95c1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -20,6 +20,13 @@ public class L2Normalize extends CompositeTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); + return new L2Normalize(arguments.get(0), dimension); + } + + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); return new Join(primitiveArgument, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 5db88953c64..c1b148ff82f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -34,6 +34,13 @@ public class Map extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); + return new Map(arguments.get(0), mapper); + } + + @Override public PrimitiveTensorFunction toPrimitive() { return new Map(argument.toPrimitive(), mapper); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4492ab083d4..8a6622213e5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -22,6 +22,13 @@ public class Matmul extends CompositeTensorFunction { public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); + return new Matmul(arguments.get(0), arguments.get(1), dimension); + } + + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument1 = argument1.toPrimitive(); TensorFunction primitiveArgument2 = argument2.toPrimitive(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index ef18cb61b17..e6f9874c0bd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -63,6 +63,13 @@ public class Reduce extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); + return new Reduce(arguments.get(0), aggregator, dimensions); + } + + @Override public PrimitiveTensorFunction toPrimitive() { return new Reduce(argument.toPrimitive(), aggregator, dimensions); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 05af86c33e8..0995e56eb9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -44,6 +44,13 @@ public class Rename extends PrimitiveTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); + return new Rename(arguments.get(0), fromDimensions, toDimensions); + } + + @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index b05b8172b42..713452d55d2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -20,6 +20,13 @@ public class Softmax extends CompositeTensorFunction { public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); + return new Softmax(arguments.get(0), dimension); + } + + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index a717292632e..34ccf0704ca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -17,6 +17,13 @@ public abstract class TensorFunction { public abstract List<TensorFunction> functionArguments(); /** + * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. + * + * @throws IllegalArgumentException if the argument list has the wrong size for this function + */ + public abstract TensorFunction replaceArguments(List<TensorFunction> arguments); + + /** * Translate this function - and all of its arguments recursively - * to a tree of primitive functions only. * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 1988c1d2390..e83a514bd13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -23,6 +23,13 @@ public class XwPlusB extends CompositeTensorFunction { public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 3) + throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); + return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); + } + + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveX = x.toPrimitive(); TensorFunction primitiveW = w.toPrimitive(); |