Existe uma biblioteca que executaria regressão linear segmentada em python?


7

Existe um pacote nomeado segmentado em R. Existe um pacote semelhante em python?


Existe algo específico que você precisa? Para pontos de interrupção conhecidos, isso pode ser modelado apenas por uma interação com uma função de indicador (0 antes, 1 após o intervalo) ou um spline linear. A primeira abordagem tem um salto, a segunda abordagem resulta em uma linha de regressão por partes conectada.
Josef

Respostas:


7

Não, atualmente não há um pacote no Python que faça regressão linear segmentada tão minuciosamente quanto os do R (por exemplo, pacotes R listados nesta postagem do blog ). Como alternativa, você pode usar um algoritmo Bayesian Markov Chain Monte Carlo no Python para criar seu modelo segmentado.

A regressão linear segmentada, conforme implementada por todos os pacotes R no link acima, não permite restrições adicionais de parâmetros (ou seja, anteriores) e, como esses pacotes adotam uma abordagem frequente, o modelo resultante não fornece distribuições de probabilidade para o modelo parâmetros (ou seja, pontos de interrupção, declives, etc.). A definição de um modelo segmentado nos modelos estatísticos , que é freqüentista, é ainda mais restritivo porque o modelo requer um ponto de interrupção fixo da coordenada x.

Você pode projetar um modelo segmentado em Python usando o emcee do algoritmo Bayesian Markov Chain Monte Carlo . Jake Vanderplas escreveu um útil post e papel para saber como implementar emcee com comparações com pymc e PyStan.

Exemplo:

  • Modelo segmentado com dados:

Regressão segmentada

  • Distribuições de probabilidade de parâmetros de ajuste:

insira a descrição da imagem aqui


0

insira a descrição da imagem aqui

Esta é uma implementação própria.

import numpy as np
import matplotlib.pylab as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

# parameters for setup
n_data = 20

# segmented linear regression parameters
n_seg = 3

np.random.seed(0)
fig, (ax0, ax1) = plt.subplots(1, 2)

# example 1
#xs = np.sort(np.random.rand(n_data))
#ys = np.random.rand(n_data) * .3 + np.tanh(5* (xs -.5))

# example 2
xs = np.linspace(-1, 1, 20)
ys = np.random.rand(n_data) * .3 + np.tanh(3*xs)

dys = np.gradient(ys, xs)

rgr = DecisionTreeRegressor(max_leaf_nodes=n_seg)
rgr.fit(xs.reshape(-1, 1), dys.reshape(-1, 1))
dys_dt = rgr.predict(xs.reshape(-1, 1)).flatten()

ys_sl = np.ones(len(xs)) * np.nan
for y in np.unique(dys_dt):
    msk = dys_dt == y
    lin_reg = LinearRegression()
    lin_reg.fit(xs[msk].reshape(-1, 1), ys[msk].reshape(-1, 1))
    ys_sl[msk] = lin_reg.predict(xs[msk].reshape(-1, 1)).flatten()
    ax0.plot([xs[msk][0], xs[msk][-1]],
             [ys_sl[msk][0], ys_sl[msk][-1]],
             color='r', zorder=1)

ax0.set_title('values')
ax0.scatter(xs, ys, label='data')
ax0.scatter(xs, ys_sl, s=3**2, label='seg lin reg', color='g', zorder=5)
ax0.legend()

ax1.set_title('slope')
ax1.scatter(xs, dys, label='data')
ax1.scatter(xs, dys_dt, label='DecisionTree', s=2**2)
ax1.legend()

plt.show()
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.