2010-09-04 10 views
19

Ogni volta che viene chiamata una funzione, se il risultato per un determinato set di valori di argomento non è ancora memoized, mi piacerebbe inserire il risultato in una tabella in memoria. Una colonna è pensata per memorizzare un risultato, altri per memorizzare i valori degli argomenti.Quale tipo utilizzare per memorizzare una tabella di dati mutabili in memoria in Scala?

Come implementarlo al meglio? Gli argomenti sono di vario tipo, inclusi alcuni enumerati.

In C# generalmente uso DataTable. Esiste un equivalente in Scala?

+2

Se si cerca il Web per "Scala Funzione Memoizzazione" troverete diversi trattamenti di questo argomento. –

risposta

25

È possibile utilizzare uno mutable.Map[TupleN[A1, A2, ..., AN], R] oppure, se la memoria è un problema, una WeakHashMap [1]. Le definizioni seguenti (create sul codice di memoizzazione da michid's blog) consentono di memorizzare facilmente funzioni con più argomenti. Ad esempio:

import Memoize._ 

def reallySlowFn(i: Int, s: String): Int = { 
    Thread.sleep(3000) 
    i + s.length 
} 

val memoizedSlowFn = memoize(reallySlowFn _) 
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds 
memoizedSlowFn(1, "abc") // returns 4 almost instantly 

Definizioni:

/** 
* A memoized unary function. 
* 
* @param f A unary function to memoize 
* @param [T] the argument type 
* @param [R] the return type 
*/ 
class Memoize1[-T, +R](f: T => R) extends (T => R) { 
    import scala.collection.mutable 
    // map that stores (argument, result) pairs 
    private[this] val vals = mutable.Map.empty[T, R] 

    // Given an argument x, 
    // If vals contains x return vals(x). 
    // Otherwise, update vals so that vals(x) == f(x) and return f(x). 
    def apply(x: T): R = vals getOrElseUpdate (x, f(x)) 
} 

object Memoize { 
    /** 
    * Memoize a unary (single-argument) function. 
    * 
    * @param f the unary function to memoize 
    */ 
    def memoize[T, R](f: T => R): (T => R) = new Memoize1(f) 

    /** 
    * Memoize a binary (two-argument) function. 
    * 
    * @param f the binary function to memoize 
    * 
    * This works by turning a function that takes two arguments of type 
    * T1 and T2 into a function that takes a single argument of type 
    * (T1, T2), memoizing that "tupled" function, then "untupling" the 
    * memoized function. 
    */ 
    def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) = 
     Function.untupled(memoize(f.tupled)) 

    /** 
    * Memoize a ternary (three-argument) function. 
    * 
    * @param f the ternary function to memoize 
    */ 
    def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) = 
     Function.untupled(memoize(f.tupled)) 

    // ... more memoize methods for higher-arity functions ... 

    /** 
    * Fixed-point combinator (for memoizing recursive functions). 
    */ 
    def Y[T, R](f: (T => R) => T => R): (T => R) = { 
     lazy val yf: (T => R) = memoize(f(yf)(_)) 
     yf 
    } 
} 

Il combinatore-punto fisso (Memoize.Y) rende possibile Memoize funzioni ricorsive:

val fib: BigInt => BigInt = {       
    def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = { 
     if (n == 0) 1 
     else if (n == 1) 1 
     else (f(n-1) + f(n-2))       
    }              
    Memoize.Y(fibRec) 
} 

[1] WeakHashMap non funziona bene come cache. Vedi http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html e this related question.

+0

Si noti che l'implementazione di cui sopra non è thread-safe, quindi se è necessario memorizzare un po 'di calcoli da più thread, questo si interromperà potenzialmente. Per cambiarlo in modo da essere thread-safe, basta: privato [questo] val vals = nuovo HashMap [T, R] con SynchronizedMap [T, R] –

+1

C'è un altro modo per la memoizzazione delle funzioni ricorsive: http: //stackoverflow.com/a/25129872/2073130, e non richiede l'uso del combinatore Y o, quindi, la formulazione di una forma non ricorsiva, che potrebbe essere scoraggiante per le funzioni ricorsive con più di un parametro. In realtà entrambi i metodi si basano sul supporto di Scala per la ricorsione della funzione, cioè quando si usa il combinatore Y 'yf' sta chiamando' yf', mentre nella variante del wrick collegato, si chiamerebbe una funzione memoizzata. – lcn

10

La versione suggerita da anovstrup utilizzando una mappa mutabile è fondamentalmente la stessa di C#, e quindi facile da usare.

Ma se lo desideri puoi anche usare uno stile più funzionale. Usa mappe immutabili, che agiscono come una specie di accumulatore. Avere Tuple (invece di Int nell'esempio) come chiavi funziona esattamente come nel caso mutabile.

def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1 

def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match { 
    case Some(f) => (f, m) 
    case None => val (f_1,m1) = fibM(n-1,m) 
       val (f_2,m2) = fibM(n-2,m1) 
       val f = f_1+f_2 
       (f, m2 + (n -> f)) 
} 

Naturalmente questo è un po 'più complicato, ma una tecnica utile sapere (si noti che il codice sopra mira per chiarezza, non per la velocità).

3

Essendo un principiante in questo argomento, ho potuto comprendere appieno nessuno degli esempi forniti (ma vorrei comunque ringraziare). Rispettosamente, presenterei la mia soluzione per il caso che qualcuno viene qui con lo stesso livello e lo stesso problema. Penso che il mio codice possa essere chiaro per chiunque abbia solo the very-very basic Scala knowledge.

 


def MyFunction(dt : DateTime, param : Int) : Double 
{ 
    val argsTuple = (dt, param) 
    if(Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param)) 
} 

def MyRawFunction(dt : DateTime, param : Int) : Double 
{ 
    1.0 // A heavy calculation/querying here 
} 

def Memoize(dt : DateTime, param : Int, result : Double) : Double 
{ 
    Memo += (dt, param) -> result 
    result 
} 

val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double] 

 

Funziona perfettamente. Apprezzerei la critica se mi fosse sfuggito qualcosa.

+1

Ho aggiunto alcuni commenti alla mia soluzione che si spera di chiarire per voi. Il vantaggio dell'approccio che ho delineato è che consente di memoizzare * qualsiasi * funzione (ok, ci sono alcune avvertenze, ma * molte funzioni *). Un po 'come la parola chiave memoize che hai postato in una domanda correlata. –

+2

L'unico aspetto che probabilmente rimane mistificante è il combinatore a virgola fissa: per questo ti incoraggio a leggere il blog di michid, a bere un sacco di caffè e magari a diventare amichevole con alcuni testi di programmazione funzionale. La buona notizia è che ne hai solo bisogno se stai memoizing una funzione ricorsiva. –

1

Quando si utilizza la mappa mutabile per la memoizzazione, si tenga presente che ciò causerebbe problemi tipici di concorrenza, ad es. fare un get quando una scrittura non è ancora stata completata. Tuttavia, il tentativo di memoizzazione thread-safe suggerisce di farlo è di poco valore se non nessuno.

Il seguente codice thread-safe crea una funzione memoized fibonacci, avvia un paio di thread (denominati da 'a' a 'd') che effettuano chiamate su di esso. Prova il codice un paio di volte (in REPL), si può facilmente vedere che f(2) set viene stampato più di una volta. Ciò significa che un thread A ha avviato il calcolo di f(2) ma Thread B non ne ha assolutamente idea e avvia la propria copia del calcolo. Tale ignoranza è così pervasiva nella fase di costruzione della cache, poiché tutti i thread non vedono alcuna soluzione secondaria stabilita e entrerebbero nella clausola else.

object ScalaMemoizationMultithread { 

    // do not use case class as there is a mutable member here 
    class Memo[-T, +R](f: T => R) extends (T => R) { 
    // don't even know what would happen if immutable.Map used in a multithreading context 
    private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R] 
    def apply(x: T): R = 
     // no synchronized needed as there is no removal during memoization 
     if (cache containsKey x) { 
     Console.println(Thread.currentThread().getName() + ": f(" + x + ") get") 
     cache.get(x) 
     } else { 
     val res = f(x) 
     Console.println(Thread.currentThread().getName() + ": f(" + x + ") set") 
     cache.putIfAbsent(x, res) // atomic 
     res 
     } 
    } 

    object Memo { 
    def apply[T, R](f: T => R): T => R = new Memo(f) 

    def Y[T, R](F: (T => R) => T => R): T => R = { 
     lazy val yf: T => R = Memo(F(yf)(_)) 
     yf 
    } 
    } 

    val fibonacci: Int => BigInt = { 
    def fiboF(f: Int => BigInt)(n: Int): BigInt = { 
     if (n <= 0) 1 
     else if (n == 1) 1 
     else f(n - 1) + f(n - 2) 
    } 

    Memo.Y(fiboF) 
    } 

    def main(args: Array[String]) = { 
    ('a' to 'd').foreach(ch => 
     new Thread(new Runnable() { 
     def run() { 
      import scala.util.Random 
      val rand = new Random 
      (1 to 2).foreach(_ => { 
      Thread.currentThread().setName("Thread " + ch) 
      fibonacci(5) 
      }) 
     } 
     }).start) 
    } 
} 
0

Oltre alla risposta di Landei, voglio anche suggerire il basso verso l'alto (non Memoizzazione) modo di fare DP a Scala è possibile, e l'idea di base è quella di utilizzare foldLeft (s).

Esempio per il calcolo di numeri di Fibonacci

def fibo(n: Int) = (1 to n).foldLeft((0, 1)) { 
    (acc, i) => (acc._2, acc._1 + acc._2) 
    }._1 

Esempio per lungo aumentare sottosequenza

def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]) = { 
    xs.foldLeft(List[(Int, List[T])]()) { 
    (memo, x) => 
     if (memo.isEmpty) List((1, List(x))) 
     else { 
     val resultIfEndsAtCurr = (memo, xs).zipped map { 
      (tp, y) => 
      val len = tp._1 
      val seq = tp._2 
      if (ord.lteq(y, x)) { // current is greater than the previous end 
       (len + 1, x :: seq) // reversely recorded to avoid O(n) 
      } else { 
       (1, List(x)) // start over 
      } 
     } 
     memo :+ resultIfEndsAtCurr.maxBy(_._1) 
     } 
    }.maxBy(_._1)._2.reverse 
} 
Problemi correlati