Ленивая динамика - Алгоритмика
Ленивая динамика

Ленивая динамика

авторы Сергей Слотин Полина Романченко

Стандартный подход в динамическом программировании — создать массив для ответов на подзадачи и пройтись по нему циклом, пересчитывая неизвестные значения из известных. Однако, иногда сложно или вообще невозможно придумать такой порядок обхода, что все необходимые значения уже посчитаны.

В подобных случаях вместо циклов можно использовать подход, называемый мемоизацией: будем считать динамику рекурсивной функцией, в которой запоминается посчитанный результат и в следующий раз сразу возвращается, когда функция вызывается от тех же входных аргументов.

Простой пример

Решим в таком стиле задачу о нахождении $n$-ого числа Фибоначчи.

Изначально положим все $f[i] = -1$: это будет обозначать, что значение для соответствующего состояния еще не посчитано. Далее, положим $f[0] = 0$ и $f[1] = 1$ как базовые значения.

Теперь напишем функцию-переход, в которая просто в самом начале проверяет, было ли уже посчитано искомое значение — и если нет, то рекурсивно его считает.

int f[n] = {-1};
f[0] = 0;
f[1] = 1;

int g(int k) {
    if (f[k] == -1)
        f[k] = g(k - 2) + g(k - 1);
    return f[k];
}

Время работы так же составит $O(n)$, так как каждое значение мы считаем только один раз, но реальное время работы будет в несколько раз больше, потому что константа на вызовы функции значительно выше, чем на простой цикл.

Также можно заметить, что в такой динамике мы гарантированно посещаем только действительно нужные состояния, что в некоторых задачах приводит к более оптимальной асимптотике.

Кэширование

В более общем смысле, то, что мы делаем, называется кэшированием — запоминанием и переиспользованием промежуточных результатов. Это очень распространенная концепция в программировании, и современные языки программирования — особенно поддерживают — обычно поддерживают мемоизацию как встроенную оптимизацию.

Например, в Python есть декоратор cache, который делает ровно это:

from functools import cache

@cache
def f(n):
    if n == 0:
        return 0
    if n == 1:
        return 1
    return f(n - 2) + f(n - 1)

При этом этому декоратору не нужно знать о каких-либо ограничениях на $n$: он запоминает результаты для любых входных и выходных данных, которые можно положить как кортежи в хеш-таблицу.

Попробуем подобное реализовать на C++.

Сложный пример

Рассмотрим задачу «Игра финансистов».

Имеется массив $a_i$ из $n \le 4000$ чисел, и два игрока по очереди берут из него числа, пытаясь максимизировать разницу свой суммы и суммы оппонента. На каждом ходу игрок может взять сколько-то крайних чисел с одной стороны — первый игрок слева, второй справа. При этом, если игрок на предыдущем ходе взял $k$ чисел, то его оппонент на следующем ходе может взять либо $k$, либо $(k + 1)$ чисел (на первом ходе можно взять 1 или 2).

Игра завершается, когда игрок не может сделать ход.

Введем следующую динамику: $f[l, r, k, p] = $ максимальная достижимая разность сумм игрока $p$ и его оппонента, если остались только элементы с $l$ по $r$, и на предыдущем ходу было взято $k$ чисел. Так как у нас есть всего два выбора — брать $k$ или $(k+1)$ чисел — то переход это просто максимум из этих двух возможностей:

$$ f[l, r, k, 0] = \max \begin{cases} \sum_{i=l}^{k} a_i + f[l + k, r, k, 1] \\ \sum_{i=l}^{k + 1} a_i + f[l + k + 1, r, k + 1, 1] \end{cases} $$

(И аналогично для другого игрока $p=1$.)

Попытаемся оценить, за сколько такое работает. Пересчет можно делать за $O(1)$ с помощью единожды посчитанных префиксных сумм исходного массива, значит асимптотика равна числу состояний. Но тут вроде как возникает проблема:

$$ l \in [1, n] \\ r \in [1, n] \\ k \in [1, n] \\ p \in \{ 0, 1 \} $$

Если наивно хранить все состояния в четырехмерном массиве, то для него потребуется $O(n^3)$ ячеек, что нас не устраивает, так как $n=4000$. Оказывается, достижимых состояний сильно меньше.

Утверждение 1. $k = O(\sqrt n)$.

Чтобы получить данное $k$, нужно, чтобы на каких-то предыдущих шагах были взяты $k, (k-1), \ldots, 2, 1$ чисел, для чего размер массива $n$ должен быть не менее

$$ n \ge \sum_{i=1}^x i \approx \frac{k^2}{2} $$

Из чего следует, что $k = O(\sqrt n)$.

Утверждение 2. Для заданной левой границы $l$ существует не более $O(\sqrt n)$ возможных правых границ $r$.

Рассмотрим разность $d$ количеств чисел, взятых первым и вторым игроком:

$$ d = l - (n - r) $$

Так как один игрок берет сколько-то элементов, а другой после этого либо берет столько же, либо на один больше, то за любые два последовательных хода $d$ изменится не более, чем на единицу. Более того, чтобы $d$ изменилось на единицу, необходимо, чтобы $k$ тоже увеличилось. Как мы уже установили, $k = O(\sqrt n)$, а значит и $d = O(\sqrt n)$. Так как по паре $(l, d)$ из определения выше восстанавливается $r$, получаем, что для каждой $l$ различных $r$ тоже будет $O(\sqrt n)$.

Получается, что на самом деле у нас не $O(n^3)$, а $O(n \cdot \sqrt n \cdot \sqrt n \cdot 2) = O(n^2)$ достижимых состояний. Значит, можно либо придумать какой-нибудь более умный формат хранения динамики и более умный обход, либо применить мемоизацию.

Так как просто создать четырехмерный массив не влезает в память (даже если он на $O(n^{2.5})$ элементов), воспользуемся вместо этого хеш-таблицей. Проще всего определить её как unordered_map из int в int и найти какую-нибудь нумерацию кортежей $(l, r, k, p)$ в какой-нибудь помещающийся в int промежуток, и вместо тюпла из 4 чисел использовать этот номер.

const int N = 4000, K = 80;
 
int n;
int s[N + 1];
 
unordered_map<int, int> dp;
 
long f(int l, int r, int k, bool p) {
    if ((r - l + 1) < k)
        return 0;
    int key = l * N * K * 2 + r * K * 2 + k * 2 + p;
    if (dp.count(key))
        return dp[key];
    return dp[key] = (p ?
        max(
            s[l + k] - s[l] - f(l + k, r, k, 0),
            s[l + k + 1] - s[l] - f(l + k + 1, r, k + 1, 0)
        )
        :
        max(
            s[r + 1] - s[r - k + 1] - f(l, r - k, k, 1),
            s[r + 1] - s[r - k] - f(l, r - k - 1, k + 1, 1)
        )
    );
}
 
int main(){
    dp.rehash(6.2e7);
 
    cin >> n;
 
    for (int i = 0; i < n; i++) {
        cin >> s[i + 1];
        s[i + 1] += s[i];
    }
 
    cout << f(0, n - 1, 1, 1); 
 
    return 0;
}

Если можно заранее оценить число элементов в хеш-таблице, то выгодно в самом начале зарезервировать место под неё через dp.rehash: это обычно ускоряет решение в 2-3 раза. Если написать свою хеш-таблицу, будет ещё в 3-4 раза быстрее.