2013-06-15 12 views
16

goworker tail-recursive loop pattern sembra funzionare molto bene per scrivere codice puro. Quale sarebbe il modo equivalente di scrivere quel tipo di loop per la monade ST? Più specificamente, voglio evitare la nuova allocazione dell'heap nelle iterazioni del ciclo. La mia ipotesi è che sia lo CPS transformation o lo fixST di riscrivere il codice in modo tale che tutti i valori che stanno cambiando attraverso il ciclo siano passati attraverso ogni iterazione, rendendo così disponibili posizioni di registro (o stack in caso di spill) per quei valori attraverso iterazione . Ho un esempio semplificato di seguito (non provare a eseguirlo - sarà probabilmente in crash con errore di segmentazione!) Che coinvolge una funzione chiamata findSnakes che ha un modello operaio go ma i valori di stato che cambiano non sono passati attraverso argomenti accumulatori:Scrittura efficiente loop iterativo per ST monad

{-# LANGUAGE BangPatterns #-} 
module Test where 

import Data.Vector.Unboxed.Mutable as MU 
import Data.Vector.Unboxed as U hiding (mapM_) 
import Control.Monad.ST as ST 
import Control.Monad.Primitive (PrimState) 
import Control.Monad as CM (when,forM_) 
import Data.Int 

type MVI1 s = MVector (PrimState (ST s)) Int 

-- function to find previous y 
findYP :: MVI1 s -> Int -> Int -> ST s Int 
findYP fp k offset = do 
       y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x 
       y1 <- MU.unsafeRead fp (k+offset+1) 
       if y0 > y1 then return y0 
       else return y1 
{-#INLINE findYP #-} 

findSnakes :: Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s() 
findSnakes a fp !k !ct !op = go 0 k 
    where 
      offset=1+U.length a 
      go x k' 
      | x < ct = do 
       yp <- findYP fp k' offset 
       MU.unsafeWrite fp (k'+offset) (yp + k') 
       go (x+1) (op k' 1) 
      | otherwise = return() 
{-#INLINE findSnakes #-} 

Guardando cmm output in ghc 7.6.1 (con la mia limitata conoscenza della cmm - si prega di correggermi se ho sbagliato), vedo questo tipo di flusso di chiamate, con passante in s1tb_info (che provoca l'allocazione di heap e heap di controllo in ogni iterazione):

findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check) 
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out 
what that register points to) 

-- I am guessing this one below is for go loop 
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1 
then s1td_info else R1 

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info 
(a loop) else s1tb_info (after executing a different block of code) 

La mia ipotesi è che il controllo del modulo arg >= 1 nel codice cmm determina se il ciclo go è terminato o meno. Se ciò è corretto, sembra che a meno che il ciclo go venga riscritto per passare yp attraverso il ciclo, l'allocazione dell'heap avverrà attraverso il loop per i nuovi valori (suppongo che l'allocazione dell'heap sia dovuta a yp). Quale sarebbe un modo efficace per scrivere il ciclo go nell'esempio sopra? Immagino che yp debba essere passato come argomento nel ciclo go o in modo equivalente tramite la trasformazione fixST o CPS. Non riesco a pensare a un buon modo per riscrivere il ciclo go sopra per rimuovere le allocazioni dell'heap e apprezzare l'aiuto con esso.

risposta

3

Ho riscritto le funzioni per evitare qualsiasi ricorsione esplicita e rimosso alcune operazioni ridondanti che calcolano gli offset. Questo compila in un nucleo molto più bello delle tue funzioni originali.

Il core, a proposito, è probabilmente il modo migliore per analizzare il codice compilato per questo tipo di profilazione. Utilizzare ghc -ddump-simpl per visualizzare l'output generato nucleo, o strumenti come ghc-core

import Control.Monad.Primitive                    
import Control.Monad.ST                      
import Data.Int                        
import qualified Data.Vector.Unboxed.Mutable as M                
import qualified Data.Vector.Unboxed as U                  

type MVI1 s = M.MVector (PrimState (ST s)) Int                

findYP :: MVI1 s -> Int -> ST s Int                                      
findYP fp offset = do                      
    y0 <- M.unsafeRead fp (offset+0)                  
    y1 <- M.unsafeRead fp (offset+2)                  
    return $ max (y0 + 1) y1                     

findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s()                           
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0          
    where writeAt k = do  
       let offset = U.length a + k                 
       yp <- findYP fp offset                   
       M.unsafeWrite fp (offset + 1) (yp + k) 

      -- or inline findYP manually 
      writeAt k = do 
      let offset = U.length a + k 
      y0 <- M.unsafeRead fp (offset + 0) 
      y1 <- M.unsafeRead fp (offset + 2) 
      M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1) 

Inoltre, si passa un U.Vector Int32 a findSnakes, solo per calcolare la sua lunghezza e non usare mai a nuovo. Perché non passare direttamente la lunghezza?