ridge/scripts/01_bgdmodel.py
2024-04-13 15:31:40 +03:00

145 lines
4.9 KiB
Python
Executable File

#!/usr/bin/env python
__author__ = "Roman Krivonos"
__copyright__ = "Space Research Institute (IKI)"
import numpy as np
import pandas as pd
from astropy.io import fits
import matplotlib.pyplot as plt
import math, sys
import pickle
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import HuberRegressor
from sklearn.linear_model import RANSACRegressor
from sklearn.linear_model import TheilSenRegressor
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedKFold
#from statsmodels.robust.scale import huber
from astropy.stats import sigma_clip
from numpy import absolute
from numpy import arange
from ridge.utils import *
from ridge.config import *
# The single command-line argument is the key that selects the input file
# "detcnts.<enkey>.fits" (and, later, names the output pickles).
if len(sys.argv) < 2:
    sys.exit("Usage: {} <enkey>".format(sys.argv[0]))
enkey = sys.argv[1]
fn="detcnts.{}.fits".format(enkey)
print("Reading {}".format(datadir+fn))
d = fits.getdata(datadir+fn)
# Swap both the stored bytes and the dtype's declared byte order, so the
# values are unchanged but the array handed to pandas has the opposite
# (presumably native) endianness.  ndarray.newbyteorder() was removed in
# NumPy 2.0; viewing the swapped data through dtype.newbyteorder() is the
# documented equivalent and also works on older NumPy releases.
rec = np.array(d)
df = pd.DataFrame(rec.byteswap().view(rec.dtype.newbyteorder()))
print(df.columns)

# Set to True to show a diagnostic plot for every fitted revolution.
plotme=False
# Clipping threshold (in standard deviations) used for outlier rejection.
sigma=3
# Running totals reported after the fitting loop.
ntotal=0
nrev=0
# Per-revolution model: rev -> {'a', 'b', 'err', 'c'}.
bgdmodel={}
# OBSID values rejected by sigma clipping.
ignored_scw=[]
def _scw_name(s):
    """Return an OBSID entry as text, decoding it when it is stored as bytes."""
    return s.decode('UTF-8') if isinstance(s, bytes) else str(s)


def _clipped_mean(y, scw, nsigma):
    """Return (sigma-clipped mean of y, names of the rejected entries of scw).

    Entries above the upper clipping bound are listed first, followed by
    those below the lower bound.  If the clipping itself fails, the plain
    mean of y is returned together with an empty rejection list.
    """
    try:
        clipped, lower, upper = sigma_clip(y, sigma=nsigma, maxiters=10, return_bounds=True)
        mean = np.mean(clipped)
    except Exception as exc:
        # Best-effort fallback, but say so instead of hiding the failure.
        print("sigma clipping failed ({}), using plain mean".format(exc))
        return np.mean(y), []
    rejected = [_scw_name(s) for s in scw[y > upper]]
    rejected += [_scw_name(s) for s in scw[y < lower]]
    return mean, rejected


def _flat_model(x, y, scw, rev, case):
    """Build a constant-level model (a = 0) for one revolution.

    The level b is the sigma-clipped mean of y; rejected OBSIDs are added
    to the module-level ignored_scw list.  `case` is the model code stored
    as 'c' and is also used in the log line and the plot title.
    Returns (a, b, err).
    """
    b, rejected = _clipped_mean(y, scw, sigma)
    ignored_scw.extend(rejected)
    print("case {}: mean {}, normal mean {}".format(case, b, np.mean(y)))
    # Root of the summed squared residuals about b, divided by the sample size.
    err = np.sqrt(np.sum((y-b)**2))/len(y)
    if plotme:
        plot_ab(x, y, 0.0, b, err, title="Case {}, MEAN rev {}".format(case, rev))
    return 0.0, b, err


# Fit one model per revolution.  The stored code 'c' records how it was made:
#   c = 0  regression of CLEAN on PHASE (many points, wide phase coverage)
#   c = 1  clipped mean (many points, narrow phase coverage)
#   c = 2  clipped mean (between nmin and nmax points)
for rev in range(revmin,revmax):
    df0 = df.query('CLEAN > 0.0 & REV == {} & abs(LAT) > {} & abs(LON) > {} & PHASE > {} & PHASE < {}'.format(rev,bmax,lmax,phmin,phmax))
    nobs = len(df0)
    if not nobs:
        continue
    print("*** REV {} ***".format(rev))
    ntotal += nobs
    nrev += 1

    x = np.array(df0['PHASE'].values)
    y = np.array(df0['CLEAN'].values)
    scw = np.array(df0['OBSID'].values)
    # The query above already restricts PHASE to (phmin, phmax), so the
    # array min/max are well defined here.
    phase_diff0 = x.max() - x.min()

    if nobs >= nmax:
        print("Phase diff: {} max allowed: {}".format(phase_diff0, phase_diff))
        if phase_diff0 > phase_diff:
            c = 0
            print("*** Run regression for {}".format(rev))
            # scikit-learn estimators expect a 2-D feature matrix.
            x = x.reshape((-1, 1))
            # Robust regressor; LinearRegression, HuberRegressor and
            # RANSACRegressor are the alternatives imported at the top, see
            # https://machinelearningmastery.com/robust-regression-for-machine-learning-in-python/
            model = TheilSenRegressor()
            # Result is not used here; the call is kept for whatever
            # evaluate_model (ridge.utils) prints or does as a side effect.
            evaluate_model(x, y, model)
            # NOTE(review): presumably fits `model` and returns the line
            # coefficients and their error -- confirm in ridge.utils.
            a,b,err = plot_best_fit(x, y, model)
            if plotme:
                plot_ab(x, y, a, b, err, title="REGRESSION rev {}".format(rev))
        else:
            c = 1
            a,b,err = _flat_model(x, y, scw, rev, c)
    elif nmin < nobs < nmax:
        c = 2
        a,b,err = _flat_model(x, y, scw, rev, c)
    else:
        # nobs <= nmin: no model is fitted.  Storing a/b/err/c here would
        # reuse the previous revolution's values (or raise NameError on the
        # first one), so leave the revolution out of bgdmodel and let the
        # interpolation pass that follows fill it from its neighbours.
        continue

    bgdmodel[rev]={'a':a, 'b':b, 'err':err, 'c':c}
print("Revs: {} Total obs.: {}".format(nrev,ntotal))

# Snapshot of the revolutions that received a fitted model.  Entries added
# below are not part of it, so only fitted revolutions serve as anchors.
keys = list(bgdmodel.keys())
for rev in range(revmin,revmax):
    if rev in keys:
        continue
    # Fill the gap by linear interpolation between the two revolutions that
    # find_nearest (ridge.utils) returns for this one; every stored field,
    # including the model code 'c', is interpolated the same way.
    left,right = find_nearest(keys, rev)
    anchors = [left,right]
    bgdmodel[rev] = {
        field: np.interp(rev, anchors, [bgdmodel[left][field], bgdmodel[right][field]])
        for field in ('a', 'b', 'err', 'c')
    }
# Output files live in proddir and are named after the input file:
# "<input>.pkl" holds the model, "<input>.ignored_scw.pkl" the rejected OBSIDs.
model_path = proddir+fn.replace(".fits",".pkl")
ignored_path = proddir+fn.replace(".fits",".ignored_scw.pkl")
for out_path, payload in ((model_path, bgdmodel), (ignored_path, ignored_scw)):
    with open(out_path, 'wb') as fp:
        pickle.dump(payload, fp, protocol=pickle.HIGHEST_PROTOCOL)
print("Removed ScWs:",ignored_scw)