# uplim/management/commands/set_contaminated.py

# TODO: add a custom flux-to-radius mapping?
# TODO: allow specifying the catalog columns?
# TODO: set the contamination flag per survey?
# TODO: include nside for surveys?

from django.core.management.base import BaseCommand
from django.db import transaction

from uplim.models import Pixel, CatalogSource

import pandas as pd
import healpy as hp
import numpy as np
from astropy.coordinates import SkyCoord

from itertools import islice
from datetime import datetime

BATCH_SIZE = 900
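# Assumption (not stated in this file): 900 keeps each `hpid__in` batch safely
# below common database parameter limits (e.g. SQLite's historical 999 bound
# variables); tune for your backend if needed.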


def batch(iterable, size):
    """Yield successive lists of at most `size` items from `iterable`."""
    iterable = iter(iterable)
    while True:
        chunk = list(islice(iterable, size))
        if not chunk:
            break
        yield chunk
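# For example (illustrative values):
#   list(batch([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]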


class Command(BaseCommand):
    help = "Set the 'contaminated' flag for all pixels based on the fluxes in the provided catalog."

    # COMMAND LINE ARGUMENTS
    # **********************

    def add_arguments(self, parser):

        parser.add_argument(
            "--catalog", type=str, required=False, help="Path to the catalog.dat file"
        )

        # parser.add_argument(
        #     '--survey',
        #     type=int,
        #     required=False,
        #     help='integer number of the survey to set the flag for'
        # )

        parser.add_argument(
            "--reset",
            action="store_true",
            default=False,
            help="Reset the contamination flag across all pixels back to False.",
        )

    def handle(self, *args, **options):

        # RESET BEHAVIOR: SET CONTAMINATION FLAG TO FALSE FOR ALL PIXELS
        # **************************************************************

        if options["reset"]:
            self.stdout.write("Resetting the contamination flag...")
            Pixel.objects.update(contaminated=False)
            self.stdout.write("Done")
            return

        if not options["catalog"]:
            self.stdout.write("No catalog file provided, exiting")
            return

        catalog_file = options["catalog"]
        self.stdout.write(f"Catalog file:\t{catalog_file}")

        # READ THE CATALOG FILE USING PANDAS READ_FWF
        # *******************************************
        # Define column positions based on the byte ranges
        colspecs = [
            (0, 4),  # SrcID (1-4)
            (5, 26),  # Name (6-26)
            (27, 37),  # RAdeg (28-37)
            (38, 48),  # DEdeg (39-48)
            (49, 55),  # ePos (50-55)
            (56, 63),  # Signi (57-63)
            (64, 76),  # Flux (65-76)
            (77, 89),  # e_Flux (78-89)
            (90, 118),  # CName (91-118)
            (119, 120),  # NewXray (120)
            (121, 134),  # Type (122-134)
        ]

        # Define column names
        colnames = [
            "SrcID",
            "Name",
            "RAdeg",
            "DEdeg",
            "ePos",
            "Signi",
            "Flux",
            "e_Flux",
            "CName",
            "NewXray",
            "Type",
        ]

        # Read the file using the fixed-width format
        catalog = pd.read_fwf(catalog_file, colspecs=colspecs, names=colnames)
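        # (pandas colspecs are 0-based, half-open intervals, hence the shift
        # relative to the catalog's 1-based byte ranges in the comments above.)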

        # Replace NaN in the string columns with empty strings
        for col in ["Name", "CName", "Type"]:
            catalog[col] = catalog[col].fillna("")

        self.stdout.write(str(catalog.head()))

        # LOAD THE CATALOG INTO THE DATABASE
        # **********************************

        existing_srcids = set(CatalogSource.objects.values_list("srcid", flat=True))

        to_create = []

        for _, row in catalog.iterrows():

            srcid = int(row["SrcID"])
            if srcid in existing_srcids:
                continue

            to_create.append(
                CatalogSource(
                    srcid=srcid,
                    name=row["Name"].strip(),
                    ra_deg=float(row["RAdeg"]),
                    dec_deg=float(row["DEdeg"]),
                    pos_error=float(row["ePos"]),
                    significance=float(row["Signi"]),
                    flux=float(row["Flux"]),
                    flux_error=float(row["e_Flux"]),
                    catalog_name=row["CName"].strip(),
                    new_xray=bool(int(row["NewXray"])),
                    source_type=row["Type"].strip(),
                )
            )

        if to_create:
            self.stdout.write(f"Inserting {len(to_create)} new catalog rows.")
            for chunk in batch(to_create, BATCH_SIZE):
                CatalogSource.objects.bulk_create(chunk, ignore_conflicts=True)
            self.stdout.write("Catalog update complete.")
        else:
            self.stdout.write("All catalog rows already exist in the database.")

        # Hard-coded nside and flux-radius mapping; maybe make these
        # configurable (see the TODOs at the top of the file).
        nside = 4096
        npix = hp.nside2npix(nside)
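        # For reference, at nside=4096 each HEALPix pixel is roughly
        # 0.86 arcmin across (hp.nside2resol(nside, arcmin=True)).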

        flux_bins = [0, 125, 250, 2000, 20000, np.inf]  # define bin edges
        mask_radii_deg = [0.06, 0.15, 0.5, 0.9, 2.5]  # corresponding mask radii in degrees

        # Convert mask radii from degrees to radians (required by query_disc)
        mask_radii = [np.radians(r) for r in mask_radii_deg]

        # Use pandas.cut to assign each source a bin index (0-4)
        catalog["flux_bin"] = pd.cut(catalog["Flux"], bins=flux_bins, labels=False)

        # manually add and change some sources
        manual_additions = pd.DataFrame(
            [
                {"RAdeg": 279.9804336, "DEdeg": 5.0669542, "flux_bin": 3},
                {"RAdeg": 266.5173685, "DEdeg": -29.1252321, "flux_bin": 3},
                # Coma Cluster, 2.5 degrees
                {"RAdeg": 194.9350000, "DEdeg": 27.9124722, "flux_bin": 4},
                # Virgo Cluster, 2.5 degrees
                {"RAdeg": 187.6991667, "DEdeg": 12.3852778, "flux_bin": 4},
            ]
        )

        catalog = pd.concat([catalog, manual_additions], ignore_index=True)

        catalog.loc[catalog["SrcID"] == 1101, "flux_bin"] = 2

        # Boolean sky mask (True = clean); the set below collects the
        # contaminated pixel indices that drive the database update.
        mask_array = np.ones(npix, dtype=bool)
        masked_pixels_set = set()

        self.stdout.write("\nCreating a list of contaminated pixels...")

        # process each source in the catalog
        for _, row in catalog.iterrows():

            # Convert the position from ICRS to galactic coordinates,
            # matching the frame of the HEALPix pixelization
            src_coord = SkyCoord(row["RAdeg"], row["DEdeg"], unit="deg", frame="icrs")
            gal = src_coord.galactic
            lon, lat = gal.l.deg, gal.b.deg

            # Get the corresponding mask radius (in radians) for this flux bin
            flux_bin = row["flux_bin"]
            if pd.isna(flux_bin):
                continue  # flux fell outside the bin edges
            radius = mask_radii[int(flux_bin)]

            # Convert (lon, lat) to HEALPix spherical coordinates:
            # theta is the colatitude, phi the longitude, both in radians
            theta = np.radians(90.0 - lat)
            phi = np.radians(lon)
            vec = hp.ang2vec(theta, phi)

            # Query all pixels within the given radius;
            # 'inclusive=True' makes sure pixels on the edge are included
            pix_indices = hp.query_disc(nside, vec, radius, inclusive=True)

            # Mark these pixels as bad (False) in the mask and record
            # their indices in the set of contaminated pixels
            mask_array[pix_indices] = False
            masked_pixels_set.update(pix_indices)
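        # Note: hp.ang2vec(lon, lat, lonlat=True) accepts degrees directly
        # and could replace the manual colatitude conversion in the loop above.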

        # Convert the set of masked pixels to a sorted list
        masked_pixels_list = sorted(masked_pixels_set)

        self.stdout.write("\nList ready, updating the database...")

        if not masked_pixels_list:
            self.stdout.write("No pixels marked as contaminated, exiting.")
            return

        total = len(masked_pixels_list)
        updated = 0
        self.stdout.write(f"\nUpdating contaminated flag in batches of {BATCH_SIZE}")

        for chunk in batch(masked_pixels_list, BATCH_SIZE):
            with transaction.atomic():
                Pixel.objects.filter(hpid__in=chunk).update(contaminated=True)

            updated += len(chunk)
            percentage = updated / total * 100

            timestamp = datetime.now().strftime("%H:%M:%S")
            self.stdout.write(
                f"[{timestamp}] {updated}/{total} ({percentage:.1f}%) updated"
            )

        self.stdout.write(f"\nMarked {updated} pixels as contaminated.")