Como extrair as regras de decisão da árvore de decisão scikit-learn?


157

Posso extrair as regras de decisão subjacentes (ou 'caminhos de decisão') de uma árvore treinada em uma árvore de decisão como uma lista textual?

Algo como:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Obrigado pela ajuda.



Você já encontrou uma resposta para esse problema? Eu tenho que exportar as regras da árvore de decisão em um formato de etapa de dados SAS que é quase exatamente o que você listou.
Zelazny7

1
Você pode usar o pacote sklearn-porter para exportar e transpilar árvores de decisão (também floresta aleatória e árvores potencializadas) para C, Java, JavaScript e outras.
Darius

Você pode verificar este link- kdnuggets.com/2017/05/…
yogesh agrawal

Respostas:


139

Eu acredito que esta resposta é mais correta do que as outras respostas aqui:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Isso imprime uma função Python válida. Aqui está um exemplo de saída para uma árvore que está tentando retornar sua entrada, um número entre 0 e 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Aqui estão alguns obstáculos que eu vejo em outras respostas:

  1. Usar tree_.threshold == -2para decidir se um nó é uma folha não é uma boa ideia. E se for um nó de decisão real com um limite de -2? Em vez disso, você deve olhar para tree.featureou tree.children_*.
  2. A linha features = [feature_names[i] for i in tree_.feature]trava com a minha versão do sklearn, porque alguns valores de tree.tree_.featuresão -2 (especificamente para nós de folha).
  3. Não há necessidade de ter várias instruções if na função recursiva, apenas uma está correta.

1
Este código funciona muito bem para mim. No entanto, eu tenho mais de 500 feature_names, então o código de saída é quase impossível para um ser humano entender. Existe uma maneira de me deixar inserir apenas os feature_names dos quais estou curioso na função?
user3768495

1
Eu concordo com o comentário anterior. IIUC, print "{}return {}".format(indent, tree_.value[node])deve ser alterado para print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))para a função retornar o índice de classe.
soupault 19/10/19

1
@ paulkernfeld Ah, sim, vejo que você pode fazer um loop RandomForestClassifier.estimators_, mas não consegui descobrir como combinar os resultados dos estimadores.
Nathan Lloyd

6
Não consegui fazer isso funcionar no python 3, os bits _tree não parecem funcionar e o TREE_UNDEFINED não foi definido. Esse link me ajudou. Embora o código exportado não seja diretamente executável em python, é c-like e muito fácil de traduzir para outros idiomas: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah

1
@ Josiah, adicione () às instruções print para fazê-lo funcionar em python3. por exemplo print "bla"=>print("bla")
Nir

48

Criei minha própria função para extrair as regras das árvores de decisão criadas pelo sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Essa função começa primeiro com os nós (identificados por -1 nas matrizes filhas) e depois localiza recursivamente os pais. Eu chamo isso de 'linhagem' de um nó. No caminho, pego os valores necessários para criar a lógica SAS if / then / else:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Os conjuntos de tuplas abaixo contêm tudo o que preciso para criar instruções if / then / else do SAS. Não gosto de usar doblocos no SAS, e é por isso que crio uma lógica que descreve o caminho inteiro de um nó. O número inteiro único após as tuplas é o ID do nó do terminal em um caminho. Todas as tuplas anteriores se combinam para criar esse nó.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Saída GraphViz da árvore de exemplo


É este tipo de árvore é correto porque col1 está vindo novamente é col1 <= 0,50000 e um col1 <= 2.5000, se sim, isso é qualquer tipo de recursão whish é usado na biblioteca
Jayant singh

o ramo direito teria registros entre (0.5, 2.5]. As árvores são feitas com particionamento recursivo. Não há nada impedindo que uma variável seja selecionada várias vezes.
precisa saber é o seguinte

tudo bem você pode explicar a parte recursão que acontece Xactly porque eu tê-lo usado no meu código e resultado similar é visto
Jayant singh

38

Modifiquei o código enviado por Zelazny7 para imprimir algum pseudocódigo:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

se você chamar get_code(dt, df.columns)o mesmo exemplo, obterá:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
Você pode dizer o que exatamente [[1. 0.]] na declaração de retorno significa na saída acima. Eu não sou um cara Python, mas estou trabalhando no mesmo tipo de coisa. Portanto, será bom para mim, se você provar alguns detalhes, para que seja mais fácil para mim.
Subhradip Bose

1
@ user3156186 Isso significa que há um objeto na classe '0' e zero objetos na classe '1'
Daniele

1
@ Daniel, você sabe como as aulas são ordenadas? Eu acho que alfanumérico, mas não encontrei confirmação em lugar algum.
IanS 4/15

Obrigado! Para o cenário caso extremo em que o valor do limiar é realmente -2, nós pode precisar mudar (threshold[node] != -2)para ( left[node] != -1)(semelhante ao método a seguir para obter ids de nós filhos)
tlingf

@ Daniel, alguma idéia de como tornar sua função "get_code" "retornar" um valor e não "imprimi-lo", porque preciso enviá-lo para outra função?
RoyaumeIX

17

O Scikit learn introduziu um novo e delicioso método chamado export_textna versão 0.21 (maio de 2019) para extrair as regras de uma árvore. Documentação aqui . Não é mais necessário criar uma função personalizada.

Depois de ajustar seu modelo, você só precisa de duas linhas de código. Primeiro, importe export_text:

from sklearn.tree.export import export_text

Segundo, crie um objeto que conterá suas regras. Para tornar as regras mais legíveis, use o feature_namesargumento e passe uma lista dos nomes dos recursos. Por exemplo, se seu modelo for chamado modele seus recursos forem nomeados em um dataframe chamado X_train, você poderá criar um objeto chamado tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Em seguida, basta imprimir ou salvar tree_rules. Sua saída ficará assim:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

Há um novo DecisionTreeClassifiermétodo decision_path, na versão 0.18.0 . Os desenvolvedores fornecem uma explicação extensa (bem documentada) .

A primeira seção do código na explicação passo a passo que imprime a estrutura da árvore parece estar OK. No entanto, modifiquei o código na segunda seção para interrogar uma amostra. Minhas alterações denotadas com# <--

Editar As alterações marcadas # <--no código abaixo foram atualizadas no link passo a passo depois que os erros foram apontados nas solicitações de recebimento nº 8653 e nº 10951 . É muito mais fácil acompanhar agora.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Altere sample_idpara ver os caminhos de decisão para outras amostras. Eu não perguntei aos desenvolvedores sobre essas mudanças, apenas pareceu mais intuitivo ao trabalhar com o exemplo.


você meu amigo é uma lenda! Alguma idéia de como plotar a árvore de decisão para essa amostra específica? muita ajuda é apreciada

1
Obrigado Victor, provavelmente é melhor fazer isso como uma pergunta separada, pois os requisitos de plotagem podem ser específicos às necessidades do usuário. Você provavelmente obterá uma boa resposta se fornecer uma idéia de como deseja que a saída seja.
22418 Kevin

hey kevin, eu criei a pergunta stackoverflow.com/questions/48888893/…

você seria tão amável para dar uma olhada em: stackoverflow.com/questions/52654280/...
Alexander Chervov

Você pode explicar a parte chamada node_index, não obtendo essa parte. O que isso faz?
Anindya Sankar Dey 30/03

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Você pode ver uma árvore do dígrafo. Então, clf.tree_.featuree clf.tree_.valuesão a matriz de nós que divide o recurso e a matriz de valores de nós, respectivamente. Você pode consultar mais detalhes nesta fonte do github .


1
Sim, eu sei desenhar a árvore - mas preciso da versão mais textual - das regras. algo como: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman

4

Só porque todos foram muito prestativos, adicionarei uma modificação às belas soluções de Zelazny7 e Daniele. Este é para python 2.7, com abas para torná-lo mais legível:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

Os códigos abaixo são minha abordagem no anaconda python 2.7 mais um nome de pacote "pydot-ng" para criar um arquivo PDF com regras de decisão. Espero que seja útil.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

um gráfico de árvore mostra aqui


3

Eu já passei por isso, mas eu precisava que as regras fossem escritas neste formato

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Então adaptei a resposta de @paulkernfeld (obrigado) que você pode personalizar de acordo com sua necessidade

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

Aqui está uma maneira de converter a árvore inteira em uma única expressão python (não necessariamente legível por humanos) usando a biblioteca SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Isso se baseia na resposta de @paulkernfeld. Se você possui um dataframe X com seus recursos e um dataframe de destino y com suas ressonâncias e deseja ter uma idéia de qual valor y terminou em qual nó (e também formiga para plotá-lo adequadamente), você pode fazer o seguinte:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

não é a versão mais elegante, mas faz o trabalho ...


1
Essa é uma boa abordagem quando você deseja retornar as linhas de código em vez de apenas imprimi-las.
Hajar Homayouni

3

Este é o código que você precisa

Modifiquei o código mais curtido para recuar em um notebook jupyter python 3 corretamente

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

Aqui está uma função, imprimindo regras de uma árvore de decisão scikit-learn no python 3 e com deslocamentos para blocos condicionais para tornar a estrutura mais legível:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

Você também pode torná-lo mais informativo, distinguindo-o a qual classe pertence ou mesmo mencionando seu valor de saída.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

insira a descrição da imagem aqui


2

Aqui está minha abordagem para extrair as regras de decisão de uma forma que possa ser usada diretamente no sql, para que os dados possam ser agrupados por nó. (Com base nas abordagens dos pôsteres anteriores.)

O resultado serão CASEcláusulas subseqüentes que podem ser copiadas para uma instrução sql, ex.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Agora você pode usar export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Um exemplo completo de [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Código de Zelazny7 modificado para buscar SQL na árvore de decisão.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Aparentemente, há muito tempo, alguém já decidiu tentar adicionar a seguinte função às funções de exportação de árvore do scikit oficial (que basicamente suportam apenas export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Aqui está o seu commit completo:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Não sei exatamente o que aconteceu com esse comentário. Mas você também pode tentar usar essa função.

Eu acho que isso justifica uma solicitação de documentação séria para o pessoal do scikit-learn documentar adequadamente a sklearn.tree.TreeAPI, que é a estrutura de árvore subjacente que DecisionTreeClassifierexpõe como seu atributo tree_.


0

Basta usar a função sklearn.tree como esta

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

E, em seguida, procure na pasta do projeto o arquivo tree.dot , copie TODO o conteúdo e cole-o aqui http://www.webgraphviz.com/ e gere seu gráfico :)


0

Obrigado pela maravilhosa solução de @paulkerfeld. No topo de sua solução, para todos aqueles que querem ter uma versão serializada de árvores, é só usar tree.threshold, tree.children_left, tree.children_right, tree.featuree tree.value. Como as folhas não têm divisões e, portanto, nenhum nome de recurso e filhos, seu espaço reservado em tree.featuree tree.children_***são _tree.TREE_UNDEFINEDe _tree.TREE_LEAF. A cada divisão é atribuído um índice exclusivo por depth first search.
Observe que o tree.valueformato está em forma[n, 1, 1]


0

Aqui está uma função que gera código Python a partir de uma árvore de decisão, convertendo a saída de export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Uso da amostra:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Saída de amostra:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

O exemplo acima é gerado com names = ['f'+str(j+1) for j in range(NUM_FEATURES)].

Um recurso útil é que ele pode gerar um tamanho de arquivo menor com espaçamento reduzido. Basta definir spacing=2.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.