¿Memoización en Haskell?


Cualquier punteros sobre cómo resolver eficientemente la siguiente función en Haskell, para números grandes (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

He visto ejemplos de memoización en Haskell para resolver fibonacci números, que implicaba computar (perezosamente) todos los números de fibonacci hasta el n requerido. Pero en este caso, para un n dado, solo necesitamos calcule muy pocos resultados intermedios.

Gracias

Author: Chetan, 2010-07-09

8 answers

Podemos hacer esto muy eficientemente haciendo una estructura que podemos indexar en tiempo sub-lineal.

Pero primero,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Definamos f, pero hagamos que use 'recursión abierta' en lugar de llamarse a sí misma directamente.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Puede obtener un f sin memoizar usando fix f

Esto le permitirá probar que f hace lo que quiere decir para valores pequeños de f llamando, por ejemplo: fix f 123 = 144

Podríamos recordar esto por definiendo:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Que funciona pasablemente bien, y reemplaza lo que iba a tomar O(n^3) tiempo con algo que recuerda los resultados intermedios.

Pero todavía toma tiempo lineal solo indexar para encontrar la respuesta memoizada para mf. Esto significa que los resultados son:

*Main Data.List> faster_f 123801
248604

Son tolerables, pero el resultado no escala mucho mejor que eso. Podemos hacerlo mejor!

Primero, definamos un árbol infinito: {[24]]}

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Y entonces defina una forma de indexar en él, para que podamos encontrar un nodo con index n en O(log n) tiempo en su lugar:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... y podemos encontrar que un árbol lleno de números naturales es conveniente para que no tengamos que jugar con esos índices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Ya que podemos indexar, solo puede convertir un árbol en una lista:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Puedes comprobar el trabajo hasta ahora verificando que toList nats te da [0..]

Ahora,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

Funciona igual que con la lista anterior, pero en lugar de tomar tiempo lineal para encontrar cada nodo, puede perseguirlo en tiempo logarítmico.

El resultado es considerablemente más rápido:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

De hecho, es mucho más rápido que puede pasar y reemplazar Int con Integer anterior y obtener respuestas ridículamente grandes casi instantáneamente

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
 229
Author: Edward KMETT,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-11-28 06:07:45

La respuesta de Edward es una gema tan maravillosa que la he duplicado y proporcionado implementaciones de combinadores memoList y memoTree que recuerdan una función en forma recursiva abierta.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
 16
Author: Tom Ellis,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-05-23 12:26:36

No es la forma más eficiente, pero recuerda:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

Al solicitar f !! 144, se comprueba que f !! 143 existe, pero su valor exacto no se calcula. Todavía está establecido como un resultado desconocido de un cálculo. Los únicos valores exactos calculados son los necesarios.

Así que inicialmente, en cuanto a cuánto se ha calculado, el programa no sabe nada.

f = .... 

Cuando hacemos la solicitud f !! 12, comienza a hacer alguna coincidencia de patrones:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora comienza a calcular

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Esto recursivamente hace otra demanda en f, por lo que calculamos

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Ahora podemos escurrir de nuevo algunos{[22]]}

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuando con el goteo:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Así que el cálculo se hace bastante perezoso. El programa sabe que existe algún valor para f !! 8, que es igual a g 8, pero no tiene idea de lo que es g 8.

 12
Author: rampion,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2010-07-08 22:13:58

Como se indica en la respuesta de Edward Kmett, para acelerar las cosas, necesita almacenar en caché cálculos costosos y poder acceder a ellos rápidamente.

Para mantener la función no monádica, la solución de construir un árbol perezoso infinito, con una forma adecuada de indexarlo (como se muestra en publicaciones anteriores) cumple ese objetivo. Si renuncia a la naturaleza no monádica de la función, puede usar los contenedores asociativos estándar disponibles en Haskell en combinación con mónadas "tipo estado" (como Estado o ST).

Si bien el principal inconveniente es que obtiene una función no monádica, ya no tiene que indexar la estructura usted mismo, y solo puede usar implementaciones estándar de contenedores asociativos.

Para hacerlo, primero debe reescribir su función para aceptar cualquier tipo de mónada:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Para sus pruebas, todavía puede definir una función que no hace memoización usando Datos.Función.fix, aunque es un poco más detallado:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Luego puede usar la mónada de estado en combinación con Datos.Mapa para acelerar las cosas:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Con cambios menores, puede adaptar el código a trabajos con Datos.HashMap en su lugar:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

En lugar de estructuras de datos persistentes, también puede probar estructuras de datos mutables (como los Datos.HashTable) en combinación con la mónada ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

En comparación con la implementación sin ninguna memoización, cualquiera de estas implementaciones le permite, para grandes entradas, obtener resultados en microsegundos en lugar de tener que esperar varios segundo.

Utilizando el Criterio como punto de referencia, pude observar que la implementación con los Datos.HashMap en realidad tuvo un rendimiento ligeramente mejor (alrededor del 20%) que los Datos.Mapa y Datos.HashTable para el que los tiempos eran muy similares.

Encontré los resultados del benchmark un poco sorprendentes. Mi sensación inicial era que la HashTable superaría a la implementación de HashMap porque es mutable. Puede haber algún defecto de rendimiento oculto en esta última implementación.

 8
Author: Quentin,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-05-16 08:48:22

Esta es una adición a la excelente respuesta de Edward Kmett.

Cuando probé su código, las definiciones de nats y index parecían bastante misteriosas, así que escribí una versión alternativa que encontré más fácil de entender.

Defino index y nats en términos de index' y nats'.

index' t n se define en el intervalo [1..]. (Recordemos que index t se define en el rango [0..].) Funciona busca en el árbol tratando n como una cadena de bits, y leyendo a través de la bits en reversa. Si el bit es 1, toma la rama derecha. Si el bit es 0, toma la rama izquierda. Se detiene cuando alcanza el último bit (que debe ser un 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Así como nats se define para index de modo que index nats n == n siempre es verdadero, nats' se define para index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Ahora, nats y index son simplemente nats' y index' pero con los valores cambiados por 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
 7
Author: Pitarou,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-06-16 04:59:35

Un par de años más tarde, miré esto y me di cuenta de que hay una manera simple de recordar esto en tiempo lineal usando zipWith y una función auxiliar:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate tiene la propiedad handy que dilate n xs !! i == xs !! div i n.

Entonces, suponiendo que se nos da f (0), esto simplifica el cálculo a

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Se parece mucho a nuestra descripción del problema original, y dar una solución lineal (sum $ take n fs tomará O(n)).

 3
Author: rampion,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-10-11 03:03:57

Otro anexo a la respuesta de Edward Kmett: un ejemplo autónomo:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Úselo de la siguiente manera para recordar una función con un único entero arg (por ejemplo, fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Solo se almacenarán en caché los valores de los argumentos no negativos.

Para almacenar en caché también los valores de los argumentos negativos, utilice memoInt, definido de la siguiente manera:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Para almacenar en caché los valores de las funciones con dos argumentos enteros use memoIntInt, definido de la siguiente manera:

memoIntInt f = memoInt (\n -> memoInt (f n))
 2
Author: Neal Young,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2014-01-26 01:38:28

Una solución sin indexación, y no basada en la de Edward KMETT.

Factorizo los subárboles comunes a un padre común (f(n/4) se comparte entre f(n/2) y f(n/4), y f(n/6) se comparte entre f(2) y f(3)). Al guardarlos como una sola variable en el padre, el cálculo del subárbol se realiza una sola vez.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

El código no se extiende fácilmente a una función de memoización general (al menos, no sabría cómo hacerlo), y realmente tienes que pensar cómo se superponen los subproblemas, pero la estrategia debería funcionar para múltiples parámetros generales no enteros. (Lo pensé para dos parámetros de cadena.)

El memo se descarta después de cada cálculo. (De nuevo, estaba pensando en dos parámetros de cadena.)

No se si esto es más eficiente que las otras respuestas. Cada búsqueda es técnicamente solo uno o dos pasos ("Mire a su hijo o al hijo de su hijo"), pero puede haber mucho uso de memoria adicional.

Editar: Esta solución aún no es correcto. El intercambio está incompleto.

Editar: Debería estar compartiendo subchildren correctamente ahora, pero me di cuenta de que este problema tiene una gran cantidad de compartir no trivial: n/2/2/2 y n/3/3 podría ser el mismo. El problema no encaja bien con mi estrategia.

 2
Author: leewz,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-04-14 02:10:32