Biblioteca Python para regressão segmentada (também conhecida como regressão por partes)


16

Estou procurando uma biblioteca Python que possa executar regressão segmentada (também conhecida como regressão por partes) .

Exemplo :

insira a descrição da imagem aqui



Esta pergunta fornece um método para executar uma regressão por partes, definindo uma função e usando bibliotecas python padrão. stackoverflow.com/questions/29382903/…

Uma pergunta semelhante ( stackoverflow.com/questions/29382903/... ) e uma biblioteca útil para a regressão piecewise ( pypi.org/project/pwlf )
prashanth

Respostas:


7

numpy.piecewise posso fazer isso.

por partes (x, condlist, funclist, * args, ** kw)

Avalie uma função definida por partes.

Dado um conjunto de condições e funções correspondentes, avalie cada função nos dados de entrada sempre que sua condição for verdadeira.

Um exemplo é dado no SO aqui . Para completar, aqui está um exemplo:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise_linear(x, x0, y0, k1, k2):
    return np.piecewise(x, [x < x0, x >= x0], [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

4

O método proposto por Vito MR Muggeo [1] é relativamente simples e eficiente. Funciona para um número especificado de segmentos e para uma função contínua. As posições dos pontos de interrupção são estimadas iterativamente , executando, para cada iteração, uma regressão linear segmentada, permitindo saltos nos pontos de interrupção. Dos valores dos saltos, as próximas posições do ponto de interrupção são deduzidas, até que não haja mais descontinuidade (saltos).

"o processo é iterado até possível convergência, o que geralmente não é garantido"

Em particular, a convergência ou o resultado pode depender da primeira estimativa dos pontos de interrupção.

Este é o método utilizado no R pacote segmentado .

Aqui está uma implementação em python:

import numpy as np
from numpy.linalg import lstsq

ramp = lambda u: np.maximum( u, 0 )
step = lambda u: ( u > 0 ).astype(float)

def SegmentedLinearReg( X, Y, breakpoints ):
    nIterationMax = 10

    breakpoints = np.sort( np.array(breakpoints) )

    dt = np.min( np.diff(X) )
    ones = np.ones_like(X)

    for i in range( nIterationMax ):
        # Linear regression:  solve A*p = Y
        Rk = [ramp( X - xk ) for xk in breakpoints ]
        Sk = [step( X - xk ) for xk in breakpoints ]
        A = np.array([ ones, X ] + Rk + Sk )
        p =  lstsq(A.transpose(), Y, rcond=None)[0] 

        # Parameters identification:
        a, b = p[0:2]
        ck = p[ 2:2+len(breakpoints) ]
        dk = p[ 2+len(breakpoints): ]

        # Estimation of the next break-points:
        newBreakpoints = breakpoints - dk/ck 

        # Stop condition
        if np.max(np.abs(newBreakpoints - breakpoints)) < dt/5:
            break

        breakpoints = newBreakpoints
    else:
        print( 'maximum iteration reached' )

    # Compute the final segmented fit:
    Xsolution = np.insert( np.append( breakpoints, max(X) ), 0, min(X) )
    ones =  np.ones_like(Xsolution) 
    Rk = [ c*ramp( Xsolution - x0 ) for x0, c in zip(breakpoints, ck) ]

    Ysolution = a*ones + b*Xsolution + np.sum( Rk, axis=0 )

    return Xsolution, Ysolution

Exemplo:

import matplotlib.pyplot as plt

X = np.linspace( 0, 10, 27 )
Y = 0.2*X  - 0.3* ramp(X-2) + 0.3*ramp(X-6) + 0.05*np.random.randn(len(X))
plt.plot( X, Y, 'ok' );

initialBreakpoints = [1, 7]
plt.plot( *SegmentedLinearReg( X, Y, initialBreakpoints ), '-r' );
plt.xlabel('X'); plt.ylabel('Y');

gráfico

[1]: Muggeo, VM (2003). Estimando modelos de regressão com pontos de interrupção desconhecidos. Estatísticas em medicina, 22 (19), 3055-3071.


3

Estive procurando a mesma coisa e, infelizmente, parece que não há uma no momento. Algumas sugestões de como proceder podem ser encontradas nesta pergunta anterior .

Como alternativa, você pode procurar em algumas bibliotecas R, por exemplo, segmentadas, SiZer, strucchange e, se algo funcionar, tente incorporar o código R em python com rpy2 .

Edição para adicionar um link ao py-earth , "Uma implementação em Python dos Splines de Regressão Adaptativa Multivariada de Jerome Friedman".


2

Há uma postagem no blog com uma implementação recursiva de regressão por partes. Essa solução se encaixa na regressão descontínua.

Se você está insatisfeito com o modelo descontínuo e deseja uma configuração contínua, proponho que você procure sua curva com base kem curvas em forma de L, usando Lasso para dispersão:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
# generate data
np.random.seed(42)
x = np.sort(np.random.normal(size=100))
y_expected = 3 + 0.5 * x + 1.25 * x * (x>0)
y = y_expected + np.random.normal(size=x.size, scale=0.5)
# prepare a basis
k = 10
thresholds = np.percentile(x, np.linspace(0, 1, k+2)[1:-1]*100)
basis = np.hstack([x[:, np.newaxis],  np.maximum(0,  np.column_stack([x]*k)-thresholds)]) 
# fit a model
model = Lasso(0.03).fit(basis, y)
print(model.intercept_)
print(model.coef_.round(3))
plt.scatter(x, y)
plt.plot(x, y_expected, color = 'b')
plt.plot(x, model.predict(basis), color='k')
plt.legend(['true', 'predicted'])
plt.xlabel('x')
plt.ylabel('y')
plt.title('fitting segmented regression')
plt.show()

Este código retornará um vetor de coeficientes estimados para você:

[ 0.57   0.     0.     0.     0.     0.825  0.     0.     0.     0.     0.   ]

Devido à abordagem Lasso, é escasso: o modelo encontrou exatamente um ponto de interrupção entre 10 possíveis. Os números 0,57 e 0,825 correspondem a 0,5 e 1,25 no DGP verdadeiro. Embora não estejam muito próximas, as curvas ajustadas são:

insira a descrição da imagem aqui

Essa abordagem não permite estimar exatamente o ponto de interrupção. Mas se o seu conjunto de dados for grande o suficiente, você poderá jogar com diferentes k(talvez ajustá-lo por validação cruzada) e estimar o ponto de interrupção com precisão suficiente.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.