"""Per-pixel ``cmap`` / ``pmap`` images from an event list and an exposure map.

The sky is tiled into square pixel blocks slightly larger than the PSF reach.
For every block, the events of its 3x3 block neighbourhood and the exposed
pixels (``emap > 0``) of the block itself are passed to
``chan_psf.solve_for_locations_eintp``; its two outputs are written into
``cmap`` and ``pmap``.
"""
import asyncio  # noqa: F401  (currently unused in this module)
from multiprocessing.pool import ThreadPool

import matplotlib.pyplot as plt  # noqa: F401  (currently unused in this module)
import numpy as np
import tqdm
from astropy.io import fits
from astropy.wcs import WCS

from chan_psf import solve_for_locations, solve_for_locations_eintp  # noqa: F401

# Energy nodes of the PSF cubes in keV (event ENERGY is divided by 1000 before
# being compared to them). NOTE(review): presumably one cube slice per node -- confirm.
psfe = np.array([1.8, 1.9, 3.0, 4.0, 6.0, 7.0, 8.0, 9.0])


def prepare_psf(evt):
    """Load every distinct PSF cube referenced by the events into one array.

    Parameters
    ----------
    evt : FITS table / structured array
        Must provide a ``psf_cube`` column holding a file path per event.

    Returns
    -------
    data : ndarray
        Stack of the unique cubes, each flipped along its last two axes.
    ui : ndarray of int
        For every event, the index of its cube in ``data``.
    """
    u, ui = np.unique(evt["psf_cube"], return_inverse=True)
    # NOTE(review): the first three characters of each stored path are dropped
    # (presumably a relative prefix such as "../") -- confirm.
    data = np.array([np.load(p[3:])[:, ::-1, ::-1].copy() for p in u])
    return data, ui.ravel()


def select_xychunksize(wcs, halfpsfsize=36./3600.):
    """Return the block size ``(sizex, sizey)`` in pixels covering the PSF reach.

    ``halfpsfsize`` is given in the units of ``wcs.wcs.cdelt`` (36 arcsec in
    degrees by default); a margin of two pixels is added on each axis.
    """
    # cdelt[0] is the pixel scale along x, cdelt[1] along y
    sizex = int(abs(halfpsfsize / wcs.wcs.cdelt[0])) + 2
    sizey = int(abs(halfpsfsize / wcs.wcs.cdelt[1])) + 2
    print(sizex, sizey)
    return sizex, sizey


def read_wcs(h):
    """Build a 2-D celestial WCS from the TC* column keywords of an event header."""
    w = WCS(naxis=2)
    w.wcs.ctype = [h["TCTYP11"], h["TCTYP12"]]
    w.wcs.crval = [h["TCRVL11"], h["TCRVL12"]]
    w.wcs.cdelt = [h["TCDLT11"], h["TCDLT12"]]
    w.wcs.crpix = [h["TCRPX11"], h["TCRPX12"]]
    # round-trip through a header so the returned object is fully initialised
    return WCS(w.to_header())


def create_neighboring_blocks(x, y, sizex, sizey):
    """Assign every event to its own block and to the 8 neighbouring blocks.

    The sky is tiled into ``sizex`` x ``sizey`` pixel blocks. Provided the PSF
    reach does not exceed the block size (currently 25 * 0.5'' with up to a
    sqrt(2) factor for unfavourable rolls, i.e. ~36''), every pixel an event
    can influence lies in that event's block or one of its 8 neighbours. The
    events collected for a block can then be fed to the PSF solver together
    with the pixels of that block.

    Parameters
    ----------
    x, y : ndarray
        Event pixel coordinates on the map grid.
    sizex, sizey : int
        Block size in pixels.

    Returns
    -------
    oidx : ndarray of int
        Event indices grouped by block, in the same block order as ``xyu``.
    xyu : ndarray, shape (2, nblocks)
        Block indices (x, y) of every block that receives at least one event.
    xyc : ndarray of int
        Number of events per block; ``oidx`` is split at ``cumsum(xyc)``.
    """
    # floor division gives the right block index for negative coordinates too
    iix = np.floor_divide(x, sizex).astype(int)
    iiy = np.floor_divide(y, sizey).astype(int)
    isx, isy = np.mgrid[-1:2, -1:2]
    oidx = np.repeat(np.arange(x.size), 9)
    bx = (iix[:, None] + isx.ravel()[None, :]).ravel()
    by = (iiy[:, None] + isy.ravel()[None, :]).ravel()
    xyu, iixy, xyc = np.unique(np.array([bx, by]), axis=1,
                               return_inverse=True, return_counts=True)
    # ravel: numpy 2.0.0 returned a non-1-D inverse when ``axis`` is given
    sord = np.argsort(iixy.ravel(), kind="stable")
    return oidx[sord], xyu, xyc


def _energy_to_psf_index(energy, grid=None):
    """Map event energies (eV) to fractional indices into the PSF energy grid.

    ``grid`` is in keV (default ``psfe``). The result ``i + f`` lies a fraction
    ``f`` of the way from node ``i`` to node ``i + 1``. Energies outside the
    grid are clamped to the first or last node.
    """
    grid = psfe if grid is None else np.asarray(grid, float)
    ekev = np.asarray(energy, float) / 1e3
    # lower bracketing node, kept in [0, n - 2] so that ``lo + 1`` is valid
    lo = np.clip(np.searchsorted(grid, ekev) - 1, 0, grid.size - 2)
    frac = (ekev - grid[lo]) / (grid[lo + 1] - grid[lo])
    return lo + np.clip(frac, 0., 1.)


def _block_pixels(emap, xs, ys, sizex, sizey):
    """Return (x, y, value) of the pixels with ``emap > 0`` inside block (xs, ys).

    Blocks partly or wholly outside the map yield only their in-map pixels (or
    none at all).
    """
    if xs < 0 or ys < 0:
        # a negative slice start would wrap to the opposite side of the map
        return np.empty(0, int), np.empty(0, int), np.empty(0, float)
    eloc = emap[ys*sizey:(ys + 1)*sizey, xs*sizex:(xs + 1)*sizex]
    yl, xl = np.nonzero(eloc > 0.)
    return xl + xs*sizex, yl + ys*sizey, eloc[yl, xl].astype(float)


def make_srccount_and_detmap(emap, evt, h, wcs=None, nthreads=8):
    """Compute the per-pixel ``cmap`` and ``pmap`` images for an event list.

    Parameters
    ----------
    emap : 2-D ndarray
        Map indexed ``[y, x]``; only pixels with ``emap > 0`` are evaluated,
        and their values are passed to the solver (presumably exposure -- confirm).
    evt : FITS table / structured array
        Events with columns ``psf_cube``, ``x``, ``y``, ``ENERGY`` (eV),
        ``src_spec``, ``bkg_spec`` and ``roll_pnt`` (degrees).
    h : FITS header
        Header of the event table (its TC* keywords define the event WCS).
    wcs : astropy.wcs.WCS, optional
        Grid of ``emap``. If omitted, the event WCS is used and the event
        ``x``/``y`` are taken as map pixels directly; otherwise the events
        are reprojected onto ``wcs``.
    nthreads : int, optional
        Number of worker threads (default 8).

    Returns
    -------
    wcs, cmap, pmap
        The map WCS and two float arrays shaped like ``emap``, filled with the
        two outputs of ``solve_for_locations_eintp``. Pixels that are not exposed
        or not near any event stay 0.
    """
    psfdata, ui = prepare_psf(evt)
    if wcs is None:
        wcs = read_wcs(h)
        x, y = evt["x"], evt["y"]
    else:
        # reproject event pixel coordinates onto the grid of the supplied wcs
        ewcs = read_wcs(h)
        xy = np.array([evt["x"], evt["y"]]).T
        x, y = wcs.all_world2pix(ewcs.all_pix2world(xy, 0), 0).T

    # NOTE(review): assumes solve_for_locations_eintp expects a fractional index
    # into the energy axis of the PSF cubes (node + weight) -- confirm.
    eidx = _energy_to_psf_index(evt["ENERGY"])

    sizex, sizey = select_xychunksize(wcs)
    iidx, xyu, cts = create_neighboring_blocks(x, y, sizex, sizey)
    cc = np.zeros(cts.size + 1, int)
    cc[1:] = np.cumsum(cts)

    cmap = np.zeros(emap.shape, float)
    pmap = np.zeros(emap.shape, float)

    # per-event quantities reordered so that block i owns slice cc[i]:cc[i + 1]
    xe = np.asarray(x, float)[iidx]
    ye = np.asarray(y, float)[iidx]
    ee = eidx[iidx]
    pk = np.asarray(evt["src_spec"] / evt["bkg_spec"], float)[iidx]
    roll = np.deg2rad(np.asarray(evt["roll_pnt"], float))[iidx]
    psfi = ui[iidx]

    def worker(ixys):
        i, (xs, ys) = ixys
        xl, yl, ell = _block_pixels(emap, xs, ys, sizex, sizey)
        if xl.size == 0:
            return xl, yl, np.empty(0, float), np.empty(0, float)
        sl = slice(cc[i], cc[i + 1])
        cr, pr = solve_for_locations_eintp(psfi[sl], ee[sl], xe[sl], ye[sl],
                                           roll[sl], pk[sl],
                                           xl.astype(float), yl.astype(float),
                                           ell, psfdata)
        return xl, yl, cr, pr

    # blocks are disjoint, so results can be written back in any order
    with ThreadPool(nthreads) as tpool:
        results = tpool.imap_unordered(worker, enumerate(xyu.T))
        for xl, yl, cr, pr in tqdm.tqdm(results, total=xyu.shape[1]):
            cmap[yl, xl] = cr
            pmap[yl, xl] = pr
    return wcs, cmap, pmap


if __name__ == "__main__":
    with fits.open("test.fits") as p1:
        emap = fits.getdata("eR_spec_asp_0.fits.gz")
        wcs, cmap, pmap = make_srccount_and_detmap(emap, p1[1].data, p1[1].header)
    # the event HDU is a table whose WCS lives in TC* column keywords (see
    # read_wcs); write the image WCS explicitly instead of reusing that header
    hdr = wcs.to_header()
    fits.HDUList([fits.PrimaryHDU(),
                  fits.ImageHDU(pmap - cmap, header=hdr),
                  fits.ImageHDU(cmap, header=hdr)]).writeto("tmap4.fits.gz", overwrite=True)