Каков наиболее эффективный способ получить пересечение k отсортированных массивов?

Учитывая k отсортированных массивов, каков наиболее эффективный способ получить пересечение этих списков

Пример

ВХОД:

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

Выход:

[1,7]

Существует способ получить объединение k отсортированных массивов на основе того, что я прочитал в книге «Элементы программирования» за nlogk time. Мне было интересно, есть ли способ сделать что-то подобное и для перекрестка

## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))
    
    res = []
 
    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt:
            heapq.heappush(heap, (nxt, ary))

РЕДАКТИРОВАТЬ: очевидно, что это вопрос алгоритма, который я пытаюсь решить, поэтому я не могу использовать какие-либо встроенные функции, такие как набор пересечений и т. Д.

Ответов (13)

Решение

Да, это возможно! Для этого я изменил ваш пример кода.

В моем ответе предполагается, что ваш вопрос касается алгоритма - если вам нужен самый быстрый код, использующий set s, см. Другие ответы.

Это сохраняет O(n log(k)) временную сложность: весь код между if lowest != elem or ary != times_seen: и unbench_all = False есть O(log(k)) . Внутри основного цикла ( for unbenched in range(times_seen): ) есть вложенный цикл, но он выполняется только times_seen раз, и times_seen изначально равен 0 и сбрасывается в 0 после каждого запуска этого внутреннего цикла, и может увеличиваться только один раз за итерацию основного цикла, поэтому внутренний цикл не может сделать больше итераций, чем основной цикл. Таким образом, поскольку код внутри внутреннего цикла выполняется O(log(k)) и выполняется не более того же количества раз, что и внешний цикл, а внешний цикл - O(log(k)) и выполняется n раз, алгоритм таков O(n log(k)) .

Этот алгоритм зависит от того, как кортежи сравниваются в Python. Он сравнивает первые элементы кортежей и, если они равны, сравнивает вторые элементы (т.е. (x, a) < (x, b) истинно тогда и только тогда a < b ). В этом алгоритме, в отличие от примера кода в вопросе, когда элемент извлекается из кучи, он не обязательно повторно выталкивается в той же итерации. Поскольку нам нужно проверить, все ли подсписки содержат один и тот же номер, после того, как номер извлекается из кучи, этот подсписок - это то, что я называю «Benched», что означает, что он не добавляется обратно в кучу. Это потому, что нам нужно проверить, содержат ли другие подсписки тот же элемент, поэтому добавлять следующий элемент этого подсписка прямо сейчас не нужно.

Если номер действительно присутствует во всех подсписках, то куча будет выглядеть примерно так [(2,0),(2,1),(2,2),(2,3)], со всеми первыми элементами кортежей одинаковыми, поэтому heappop будет выбран тот, у которого индекс подсписка самый низкий. Это означает, что первый индекс 0 будет извлечен и times_seen увеличен до 1, затем индекс 1 будет извлечен и times_seen увеличен до 2 - если ary не равен, times_seen то число не находится на пересечении всех подсписок. Это приводит к условию if lowest != elem or ary != times_seen:, которое определяет, когда число не должно быть в результате. else Ветвь этого if утверждения, когда он все еще может быть в результате.

unbench_all Логическое значение для когда все суб-листы должны быть удалены со скамейки - это может быть:

  1. Известно, что текущий номер не находится на пересечении подсписок
  2. Известно, что он находится на пересечении подсписок

Когда unbench_all есть True, все подсписки, которые были удалены из кучи, добавляются заново. Известно , что это те , с индексами в range(times_seen) так как алгоритм удаляет элементы из кучи , только если они имеют один и тот же номер, поэтому они должны быть удалены в порядке индекса, смежно и , начиная с индекса 0, и должны быть times_seen в их. Это означает, что нам не нужно хранить индексы выбранных подсписок, а только их количество.

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of tims that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res


if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))

Эквивалентный алгоритм можно записать так, если хотите:

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt:
                        heapq.heappush(heap, (nxt, ary))

                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)

        if keep_elem:
            res.append(elem)

        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt:
                heapq.heappush(heap, (nxt, unbenched))

        if len(heap) < len(srtd_arys):
            heap = []

    return res

Использование порядка сортировки

Вот подход O (n), который не требует каких-либо специальных структур данных или вспомогательной памяти, помимо фундаментального требования одного итератора и одного значения на подсписок:

from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result

Вот пример сеанса:

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]

Алгоритм на словах

Алгоритм перебирает итератор, пары значений. Если значение совпадает во всех парах, оно принадлежит пересечению. Если значение ниже, чем любое другое, наблюдаемое до сих пор, текущий итератор продвигается. Если значение больше, чем любое из наблюдаемых до сих пор, оно становится новой целью, и счетчик совпадений сбрасывается до единицы. Когда любой итератор исчерпан, алгоритм завершен.

Не зависит от встроенных функций

Использование itertools.cycle () совершенно необязательно. Это легко эмулировать, увеличивая индекс, который оборачивается в конце.

Вместо того:

iterator, value = pair = next(pairs)

Вы могли написать:

pairnum += 1
if pairnum == n:
    pairnum = 0
iterator, value = pair = pairs[pairnum]    

Или более компактно:

pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum] 

Повторяющиеся значения

Если повторы должны сохраняться (например, мультимножество), это простая модификация, просто измените четыре строки после, result.append(curr) чтобы удалить соответствующий элемент из каждого итератора:

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            for i in range(n):
                iterator, value = pair = next(pairs)
                pair[VALUE] = next(iterator)
            curr, matches = pair[VALUE], 1
    except StopIteration:
        return result

Вы можете использовать reduce :

from functools import reduce

a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
 {1, 7}

Я придумал этот алгоритм. Оно не превышает O (n k). Не знаю, достаточно ли оно для вас. суть этого алгоритма в том, что вы можете иметь k индексов для каждого массива, и на каждой итерации вы находите индексы следующего элемента в пересечении и увеличиваете каждый индекс до тех пор, пока вы не превысите границы массива и в пересечении больше не останется элементов . Хитрость заключается в том, что поскольку массивы отсортированы, вы можете смотреть на два элемента в двух разных массивах, и если один из них больше другого, вы можете немедленно выбросить другой, потому что вы знаете, что у вас не может быть меньшего числа, чем тот, на который вы смотрите. наихудший случай этого алгоритма состоит в том, что каждый индекс будет увеличен до границы, что займет k n времени, поскольку индекс не может уменьшить свое значение.

  inter = []

  for n in range(len(arrays[0])):
    if indexes[0] >= len(arrays[0]):
        return inter
    for i in range(1,k):
      if indexes[i] >= len(arrays[i]):
        return inter
      while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < arrays[0][indexes[0]]:
        indexes[i] += 1
      while indexes[i] < len(arrays[i]) and indexes[0] < len(arrays[0]) and arrays[i][indexes[i]] > arrays[0][indexes[0]]:
        indexes[0] += 1
    if indexes[0] < len(arrays[0]):
      inter.append(arrays[0][indexes[0]])
    indexes = [idx+1 for idx in indexes]
  return inter

Вы сказали, что мы не можем использовать наборы, но как насчет dicts / хеш-таблиц? (да, я знаю, что это в основном одно и то же): D

Если да, то вот довольно простой подход (прошу прощения за синтаксис py2):

arrays = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
counts = {}

for ar in arrays:
  last = None
  for i in ar:
    if (i != last):
      counts[i] = counts.get(i, 0) + 1
    last = i

N = len(arrays)
intersection = [i for i, n in counts.iteritems() if n == N]
print intersection

То же, что и решение Раймонда Хеттингера, но с более простым кодом Python:

def intersection(arrays, unique: bool=False):
    result = []
    if not len(arrays) or any(not len(array) for array in arrays):
        return result

    pointers = [0] * len(arrays)

    target = arrays[0][0]
    start_step = 0
    current_step = 1
    while True:
        idx = current_step % len(arrays)
        array = arrays[idx]

        while pointers[idx] < len(array) and array[pointers[idx]] < target:
            pointers[idx] += 1

        if pointers[idx] < len(array) and array[pointers[idx]] > target:
            target = array[pointers[idx]]
            start_step = current_step
            current_step += 1
            continue

        if unique:
            while (
                pointers[idx] + 1 < len(array)
                and array[pointers[idx]] == array[pointers[idx] + 1]
            ):
                pointers[idx] += 1

        if (current_step - start_step) == len(arrays):
            result.append(target)
            for other_idx, other_array in enumerate(arrays):
                pointers[other_idx] += 1
            if pointers[idx] < len(array):
                target = array[pointers[idx]]
                start_step = current_step

        if pointers[idx] == len(array):
            return result

        current_step += 1

Вот ответ O (n) (где n = sum(len(sublist) for sublist in data) ).

from itertools import cycle

def intersection(data):
    result = []    
    maxval = float("-inf")
    consecutive = 0
    try:
        for sublist in cycle(iter(sublist) for sublist in data):

            value = next(sublist)
            while value < maxval:
                value = next(sublist)

            if value > maxval:
                maxval = value
                consecutive = 0
                continue

            consecutive += 1
            if consecutive >= len(data)-1:
                result.append(maxval)
                consecutive = 0

    except StopIteration:
        return result

print(intersection([[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]))

[1, 7]

Некоторые из вышеперечисленных методов не охватывают примеры, когда есть дубликаты в каждом подмножестве списка. Код ниже реализует это пересечение, и он будет более эффективным, если в подмножестве списка будет много дубликатов :) Если не уверены в дубликатах, рекомендуется использовать счетчик из коллекций from collections import Counter . Функция настраиваемого счетчика предназначена для повышения эффективности обработки больших дубликатов. Но все равно не может превзойти реализацию Раймона Хеттингера.

def counter(my_list):
    my_list = sorted(my_list)
    first_val, *all_val = my_list
    p_index = my_list.index(first_val)
    my_counter = {}
    for item in all_val:
         c_index = my_list.index(item)
         diff = abs(c_index-p_index)
         p_index = c_index
         my_counter[first_val] = diff 
         first_val = item
    c_index = my_list.index(item)
    diff = len(my_list) - c_index
    my_counter[first_val] = diff 
    return my_counter

def my_func(data):
    if not data or not isinstance(data, list):
        return
    # get the first value
    first_val, *all_val = data
    if not isinstance(first_val, list):
        return
    # count items in first value
    p = counter(first_val) # counter({1: 2, 3: 1, 5: 1, 7: 1})
    # collect all common items and calculate the minimum occurance in intersection
    for val in all_val:
        # collecting common items
        c = counter(val)
        # calculate the minimum occurance in intersection
        inner_dict = {}
        for inner_val in set(c).intersection(set(p)):
            inner_dict[inner_val] = min(p[inner_val], c[inner_val])
        p = inner_dict
    # >>>p
    # {1: 2, 7: 1}
    # Sort by keys of counter
    sorted_items = sorted(p.items(), key=lambda x:x[0]) # [(1, 2), (7, 1)]
    result=[i[0] for i in sorted_items for _ in range(i[1])] # [1, 1, 7]
    return result

Вот примеры примеров

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> my_func(data=data)
[1, 7]
>>> data = [[1,1,3,5,7],[1,1,3,5,7],[1,1,4,7,9]]
>>> my_func(data=data)
[1, 1, 7]

Вы можете сделать следующее, используя функции heapq.merge , chain.from_iterable и groupby.

from heapq import merge
from itertools import groupby, chain

ls = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]


def index_groups(lst):
    """[1, 1, 3, 5, 7] -> [(1, 0), (1, 1), (3, 0), (5, 0), (7, 0)]"""
    return chain.from_iterable(((e, i) for i, e in enumerate(group)) for k, group in groupby(lst))


iterables = (index_groups(li) for li in ls)
flat = merge(*iterables)
res = [k for (k, _), g in groupby(flat) if sum(1 for _ in g) == len(ls)]
print(res)

Выход

[1, 7]

Идея состоит в том, чтобы дать дополнительное значение (используя enumerate), чтобы различать равные значения в одном списке (см. Функцию index_groups ).

Сложность этого алгоритма в том, O(n) где n - сумма длин каждого списка во входных данных.

Обратите внимание, что вывод для (дополнительный 1 en в каждом списке):

ls = [[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]

является:

[1, 1, 7]

Вот алгоритм подсчета за один проход, упрощенная версия того, что предлагали другие.

def intersection(iterables):
    target, count = None, 0
    for it in itertools.cycle(map(iter, iterables)):
        for value in it:
            if count == 0 or value > target:
                target, count = value, 1
                break
            if value == target:
                count += 1
                break
        else:  # exhausted iterator
            return
        if count >= len(iterables):
            yield target
            count = 0

Двоичный и экспоненциальный поиск еще не появились. Их легко воссоздать даже с ограничением «no builtins».

На практике это было бы намного быстрее и сублинейно. В худшем случае - когда перекресток не сокращается - наивный подход повторит работу. Но для этого есть решение: интегрировать двоичный поиск при разделении массивов пополам.

def intersection(seqs):
    seq = min(seqs, key=len)
    if not seq:
        return
    pivot = seq[len(seq) // 2]
    lows, counts, highs = [], [], []
    for seq in seqs:
        start = bisect.bisect_left(seq, pivot)
        stop = bisect.bisect_right(seq, pivot, start)
        lows.append(seq[:start])
        counts.append(stop - start)
        highs.append(seq[stop:])
    yield from intersection(lows)
    yield from itertools.repeat(pivot, min(counts))
    yield from intersection(highs)

Оба обрабатывают дубликаты. Оба гарантируют время наихудшего случая O (N) (считая срез как атомарный). Последняя приблизится к скорости O (min_size); всегда разделяя наименьшее пополам, он, по сути, не может пострадать от неравномерного разделения.

Вы можете использовать битовую маску с горячим кодированием. Внутренние списки становятся maxterms. Вы и они вместе для пересечения и / или они для союза. Затем вам нужно конвертировать обратно, для чего я немного взломал .

problem = [[1,3,5,7],[1,1,3,5,8,7],[1,4,7,9]];

debruijn = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
    31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9];
u32 = accum = (1 << 32) - 1;
for vec in problem:
    maxterm = 0;
    for v in vec:
        maxterm |= 1 << v;
    accum &= maxterm;

# https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogDeBruijn
result = [];
while accum:
    power = accum;
    accum &= accum - 1; # Peter Wegner CACM 3 (1960), 322
    power &= ~accum;
    result.append(debruijn[((power * 0x077CB531) & u32) >> 27]);

print result;

При этом используются (имитируются) 32-битные целые числа, поэтому вы можете иметь их только [0, 31] в своих наборах.

* Я неопытен в Python, поэтому рассчитал время. Обязательно нужно использовать set.intersection .

O (n), но Set в 5,5 раз быстрее.

Вы можете использовать встроенные наборы и наборы пересечений:

d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
result = set(d[0]).intersection(*d[1:])
{1, 7}

Я не мог не заметить, что это, похоже, разновидность проблемы «Жулик благосостояния»; см. книгу Дэвида Гриса Наука программирования . Эдсгер Дейкстра также написал об этом EWD, см. Восходящие функции и Социальный мошенник .

Жулик благосостояния

Предположим, у нас есть три длинные магнитные ленты, каждая из которых содержит список имен в алфавитном порядке:

  • все люди, работающие в IBM Yorktown
  • студенты Колумбийского университета
  • люди на пособие в Нью-Йорке

Практически все три списка бесконечны, поэтому верхние границы не приводятся. Известно, что хотя бы один человек числится во всех трех списках. Напишите программу, чтобы найти первого такого человека.

Наше пересечение проблемы упорядоченных списков является обобщением проблемы Крука благосостояния.

Вот (довольно примитивное?) Решение Python проблемы Welfare Crook:

def find_welfare_crook(f, g, h, i, j, k):
    """f, g, and h are "ascending functions," i.e.,
i <= j implies f[i] <= f[j] or, equivalently,
f[i] < f[j] implies i < j, and the same goes for g and h.
i, j, k define where to start the search in each list.
"""
    # This is an implementation of a solution to the Welfare Crook
    # problems presented in David Gries's book, The Science of Programming.
    # The surprising and beautiful thing is that the guard predicates are
    # so few and so simple.
    i , j , k = i , j , k
    while True:
        if f[i] < g[j]:
            i += 1
        elif g[j] < h[k]:
            j += 1
        elif h[k] < f[i]:
            k += 1
        else:
            break
    return (i,j,k)
    # The other remarkable thing is how the negation of the guard
    # predicates works out to be:  f[i] == g[j] and g[j] == c[k].

Обобщение на пересечение K списков

Это обобщается на списки K , и вот что я придумал; Я не знаю, насколько это Pythonic, но он довольно компактен:

def findIntersectionLofL(lofl):
    """Generalized findIntersection function which operates on a "list of lists." """
    K = len(lofl)
    indices = [0 for i in range(K)]
    result = []
    #
    try:
        while True:
            # idea is to maintain the indices via a construct like the following:
            allEqual = True
            for i in range(K):
                if lofl[i][indices[i]] < lofl[(i+1)%K][indices[(i+1)%K]] :
                    indices[i] += 1
                    allEqual = False
            # When the above iteration finishes, if all of the list
            # items indexed by the indices are equal, then another
            # item common to all of the lists must be added to the result.
            if allEqual :
                result.append(lofl[0][indices[0]])
                while lofl[0][indices[0]] == lofl[1][indices[1]]:
                    indices[0] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised.  This means the algorithm
        # is finished.
        return result

Это решение не сохраняет повторяющиеся элементы. Изменение программы для включения всех повторяющихся элементов путем изменения того, что программа делает в условном выражении в конце цикла «while True», - это упражнение, оставленное читателю.

Улучшенная производительность

Комментарии от @greybeard вызвали уточнения, показанные ниже, в предварительном вычислении «модулей индекса массива» (выражения «(i + 1)% K»), а дальнейшее исследование также привело к изменениям во внутренней структуре итерации с целью дальнейшего удаления накладные расходы:

def findIntersectionLofLunRolled(lofl):
    """Generalized findIntersection function which operates on a "list of lists."
Accepts a list-of-lists, lofl.  Each of the lists must be ordered.
Returns the list of each element which appears in all of the lists at least once.
"""
    K = len(lofl)
    indices = [0] * K
    result = []
    lt = [ (i, (i+1) % K) for i in range(K) ] # avoids evaluation of index exprs inside the loop
    #
    try:
        while True:
            allUnEqual = True
            while allUnEqual:
                allUnEqual = False
                for i,j in lt:
                    if lofl[i][indices[i]] < lofl[j][indices[j]]:
                        indices[i] += 1
                        allUnEqual = True
            # Now all of the lofl[i][indices[i]], for all i, are the same value.
            # Store that value in the result, and then advance all of the indices
            # past that common value:
            v = lofl[0][indices[0]]
            result.append(v)
            for i,j in lt:
                while lofl[i][indices[i]] == v:
                    indices[i] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised.  This means the algorithm
        # is finished.
        return result