// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;

/**
 * The <i>merge</i> tensor operation produces from two argument tensors having equal types
 * a tensor having the same type where the values are the union of the values of both tensors. In the cases where both
 * tensors contain a value for a given cell, and only then, the lambda scalar expression is evaluated to produce
 * the resulting cell value.
 *
 * @author bratseth
 */
public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argumentA, argumentB;
    private final DoubleBinaryOperator merger;

    /**
     * Creates a merge of two tensor functions.
     *
     * @param argumentA the first argument; its cell values are passed as the left operand to {@code merger}
     * @param argumentB the second argument; its cell values are passed as the right operand to {@code merger}
     * @param merger    the function applied to produce the value of a cell present in both arguments
     * @throws NullPointerException if any argument is null
     */
    public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(merger, "The merger function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.merger = merger;
    }

    /** Returns the type resulting from applying Merge to the two given types */
    public static TensorType outputType(TensorType a, TensorType b) {
        return TypeResolver.merge(a, b);
    }

    /** Returns the function used to combine the values of cells present in both arguments */
    public DoubleBinaryOperator merger() { return merger; }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argumentA, argumentB); }

    /**
     * Returns a copy of this merge with the given arguments, keeping the same merger function.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two functions
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if ( arguments.size() != 2)
            throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size());
        return new Merge<>(arguments.get(0), arguments.get(1), merger);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return outputType(argumentA.type(context), argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        // The result type is derived from the runtime argument types, not from the declared ones.
        TensorType mergedType = outputType(a.type(), b.type());
        return evaluate(a, b, mergedType, merger);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
    }

    // NOTE(review): hashCode is overridden here but equals is not overridden in this class;
    // presumably equality is defined by a superclass - confirm the two stay consistent.
    @Override
    public int hashCode() { return Objects.hash("merge", argumentA, argumentB, merger); }

    /**
     * Merges two concrete tensors into a tensor of the given type.
     * Cells present in only one argument are copied; cells present in both are combined as
     * {@code combinator.applyAsDouble(aValue, bValue)}.
     *
     * @param a          the first tensor; supplies the left operand to {@code combinator}
     * @param b          the second tensor; supplies the right operand to {@code combinator}
     * @param mergedType the type of the resulting tensor
     * @param combinator the function combining the values of cells present in both tensors
     * @return the merged tensor
     */
    static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        // Choose merge algorithm. The fast path walks the value iterators of both tensors in lock-step,
        // which is only valid for dense vectors over the same dimension. The instanceof checks guard the
        // casts below: a Tensor of single-indexed-dimension type that is not an IndexedTensor falls back
        // to the general algorithm instead of failing with a ClassCastException.
        if (a instanceof IndexedTensor && b instanceof IndexedTensor
            && hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b)
            && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
            return indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator);
        else
            return generalMerge(a, b, mergedType, combinator);
    }

    /** Returns whether the type of the given tensor has exactly one dimension, and that dimension is indexed */
    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    /**
     * Merges two dense vectors over the same dimension.
     * Positions below the shorter vector's size are combined; the remaining positions up to the longer
     * vector's size are copied from the longer vector.
     * NOTE(review): assumes {@code type} can hold max(aSize, bSize) cells - this relies on TypeResolver.merge.
     */
    private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
        long aSize = a.dimensionSizes().size(0);
        long bSize = b.dimensionSizes().size(0);
        long mergedSize = Math.max(aSize, bSize);
        long sharedSize = Math.min(aSize, bSize);
        Iterator<Double> aIterator = a.valueIterator();
        Iterator<Double> bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (long i = 0; i < sharedSize; i++)
            builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
        // Both iterators have advanced sharedSize steps, so the longer one is positioned at its tail.
        Iterator<Double> largestIterator = aSize > bSize ? aIterator : bIterator;
        for (long i = sharedSize; i < mergedSize; i++)
            builder.cell(largestIterator.next(), i);
        return builder.build();
    }

    /** Merges two tensors of any type by address lookup. */
    private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(mergedType);
        // First pass: all cells of a, combining those also present in b.
        addCellsOf(a, b, builder, combinator);
        // Second pass: only the cells of b that are absent from a. Shared cells were already added above,
        // so no combinator is passed.
        addCellsOf(b, a, builder, null);
        return builder.build();
    }

    /**
     * Adds every cell of {@code a} to {@code builder}.
     * Cells whose address is absent from {@code b} are copied as-is. Cells present in both are added as
     * {@code combinator.applyAsDouble(aValue, bValue)}, or skipped entirely when {@code combinator} is null.
     */
    private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) {
        for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> aCell = i.next();
            var key = aCell.getKey();
            if (! b.has(key)) {
                builder.cell(key, aCell.getValue());
            } else if (combinator != null) {
                builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
            }
        }
    }

}