2012-05-21 19 views
7

Sono un po 'imbarazzato ad ammetterlo, ma mi sembra di essere piuttosto sconcertato da quello che dovrebbe essere un semplice problema di programmazione. Sto costruendo un'implementazione di un albero decisionale e sto usando la ricorsione per prendere una lista di campioni etichettati, dividere ricorsivamente l'elenco a metà e trasformarlo in un albero.Codifica ricorsiva della creazione di alberi con loop while + stack

Sfortunatamente, con alberi profondi mi imbatto in errori di overflow dello stack (ha!), Quindi il mio primo pensiero è stato quello di utilizzare le continuazioni per trasformarlo in ricorsione di coda. Sfortunatamente Scala non supporta quel tipo di TCO, quindi l'unica soluzione è usare un trampolino. Un trampolino sembra un po 'inefficiente e speravo che ci sarebbe stata una semplice soluzione imperativa basata su stack per questo problema, ma ho molti problemi a trovarlo.

La versione ricorsiva sembra un po 'come (semplificato):

private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = { 
    if (shouldStop(samples)) { 
    DTLeaf(makeProportions(samples)) 
    } else { 
    val featureIdx = getSplittingFeature(samples, usedFeatures) 
    val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
    DTBranch(
     trainTree(statsWithFeature, usedFeatures + featureIdx), 
     trainTree(statsWithoutFeature, usedFeatures + featureIdx), 
     featureIdx) 
    } 
} 

Quindi, fondamentalmente sto ricorsivamente suddividere la lista in due secondo alcuni caratteristica dei dati, e passando attraverso un elenco di funzioni utilizzate in modo Non ripeto: tutto questo viene gestito nella funzione "getSplittingFeature" in modo che possiamo ignorarlo. Il codice è davvero semplice! Tuttavia, ho difficoltà a trovare una soluzione basata su stack che non usi solo chiusure e diventi effettivamente un trampolino. So che dovremo almeno tenere un po 'di "frame" di argomenti nello stack, ma vorrei evitare le chiamate di chiusura.

Ho capito che dovrei scrivere esplicitamente ciò che il callstack e il contatore del programma gestiscono implicitamente nella soluzione ricorsiva, ma ho difficoltà a farlo senza continuazioni. A questo punto non si parla nemmeno di efficienza, sono solo curioso. Quindi, per favore, non c'è bisogno di ricordarmi che l'ottimizzazione prematura è la radice di tutto il male e la soluzione basata sul trampolino probabilmente funzionerà perfettamente. So che probabilmente lo farò - questo è fondamentalmente un enigma per il proprio bene.

Qualcuno può dirmi qual è la soluzione canonica basata su loop e stack per questo genere di cose?

AGGIORNAMENTO: Sulla base dell'eccellente soluzione di Thipor Kong, ho codificato l'implementazione di un algoritmo while-loops/stack/hashtable che dovrebbe essere una traduzione diretta della versione ricorsiva. Questo è esattamente quello che stavo cercando:

AGGIORNAMENTO FINALE: ho usato indici interi sequenziali, oltre a rimettere tutto in array invece di mappe per prestazioni, aggiunto supporto maxDepth e infine avere una soluzione con lo stesso prestazioni della versione ricorsiva (non sono sicuro circa l'utilizzo della memoria, ma direi meno):

private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = { 
    // Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit 
    type DenseIntMap[T] = ArrayBuffer[T] 
    def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = { 
    if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) } 
    ab.update(idx, item) 
    } 
    var currentChildId = 0 // get childIdx or create one if it's not there already 
    def child(childMap: DenseIntMap[Int], heapIdx: Int) = 
    if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx) 
    else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId } 
    // go down 
    val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx 
    val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx 
    val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx 
    val nodes = new DenseIntMap[DTree]() // heapIdx -> node 
    while (!todo.isEmpty) { 
    val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop() 
    if (shouldStop(samples) || maxDepth == 0) { 
     updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples))) 
    } else { 
     val featureIdx = getSplittingFeature(samples, usedFeatures) 
     val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
     todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx))) 
     todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx))) 
     branches.push((heapIdx, featureIdx)) 
    } 
    } 
    // go up 
    while (!branches.isEmpty) { 
    val (heapIdx, featureIdx) = branches.pop() 
    updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx)) 
    } 
    nodes(0) 
} 
+0

L'offload non è un'implementazione basata su stack (dove lo stack è nell'heap) concettualmente come il trampolino? – ron

+0

Un po ', ma il trampolino significa che si sta mantenendo uno stack pieno di chiusure, dove spero che ci sia una soluzione che utilizza solo uno stack pieno di dati. Forse dati etichettati come StepOne (a, b, c), StepTwo (a, b, c) o stack multipli o qualcosa del genere, ma nessuna chiamata di funzione sarebbe coinvolta. – lvilnis

+0

Apporta un'altra modifica al mio codice. Lo spazio dei nomi degli id ​​dei nodi viene utilizzato in modo più economico ed è possibile collegare il proprio tipo di id dei nodi (o BigInt, se lo si desidera). –

risposta

3

Basta memorizzare l'albero binario in un array, come descritto a Wikipedia: Per il nodo i, il figlio sinistro va in 2*i+1 e il bambino giusto in 2*i+2. Quando fai "down", tieni una collezione di cose da fare, che devono ancora essere divise per raggiungere una foglia. Una volta che hai solo le foglie, per costruire i nodi decisionali:

Aggiornamento: Una versione ripulita, che supporta anche le funzioni memorizzate nei rami (tipo parametro B) e che è più funzionale/completamente puro e che supporta alberi sparsi con una mappa come suggerito da Ron.

Aggiornamento2-3: Utilizzare in modo economico lo spazio dei nomi per gli id ​​dei nodi e il tipo di ID astratto per consentire gli alberi di grandi dimensioni. Prendi gli id ​​dei nodi da Stream.

sealed trait DTree[A, B] 
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B] 
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B] 

def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = { 
    @tailrec 
    def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) = 
    todo match { 
     case Nil => (branches, leafs) 
     case (a, b, id) :: rest => 
     split(a, b) match { 
      case None => 
      goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids) 
      case Some((left, right, b2)) => 
      val leftId #:: rightId #:: idRest = ids 
      goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest) 
     } 
    } 

    @tailrec 
    def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] = 
    branches match { 
     case Nil => nodes 
     case (id, b, leftId, rightId) :: rest => 
     goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b))) 
    } 

    val rootId #:: restIds = ids 
    val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds) 
    goUp(branches, leafs)(rootId) 
} 

// try it out 

def split(xs: Seq[Int], b: Int) = 
    if (xs.size > 1) { 
    val (left, right) = xs.splitAt(xs.size/2) 
    Some((left, right, b + 1)) 
    } else { 
    None 
    } 

val tree = mktree(0 to 1000, 0, split _, Stream.from(0)) 
println(tree) 
+0

Che dire del fatto che ogni DTBranch ha bisogno di un "featureIndex"? Ciò rende un po 'più complicato dal momento che per trasformare tutte le foglie in rami abbiamo bisogno del loro featureIndex, e quindi per combinare insieme questi rami abbiamo bisogno dei loro featureIndex e così via. Penso che questa sia l'idea giusta, quindi la giocherò. – lvilnis

+0

Metti le featureIndice nell'heap quando vai giù (invece del None), per averlo a disposizione per creare il DTBranch, quando salirai di nuovo. –

+0

È fantastico! Lo proverò e segnerò la tua come risposta entro un'ora. – lvilnis