Last active: September 6, 2023 08:32
Revisions
cphyc revised this gist
Sep 6, 2023. 1 changed file with 1 addition and 2 deletions.
@@ -6,10 +6,9 @@
 import h5py
 import numpy as np
 import yt
-from astrophysics_toolset.utilities.logging import logger
+from yt import mylog as logger
 from yt.fields.derived_field import ValidateSpatial

 yt.enable_parallelism()
 logger.setLevel(10)
-
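This revision drops the dependency on the external astrophysics_toolset logging module in favour of yt's built-in mylog logger, so the script only depends on yt, numpy, and h5py.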
cphyc created this gist
Sep 6, 2023.
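The script below scans a simulation folder for RAMSES-style output_????? directories (or .tar.gz archives of them), selects the tracer particles of each snapshot, samples a set of gas fields (velocity, density, temperature, species number densities, sound speed, a locally computed velocity dispersion, heating/cooling rates, and vorticity) at the tracer positions, and writes the result for each output to an HDF5 file in a sibling subset/ folder. A hedged usage sketch follows the listing.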
import argparse
import gc
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

import h5py
import numpy as np
import yt
from astrophysics_toolset.utilities.logging import logger
from yt.fields.derived_field import ValidateSpatial

yt.enable_parallelism()
logger.setLevel(10)


def setup_dataset(ds):
    # Tracer particles are those with a non-positive particle family.
    @yt.particle_filter(requires=["particle_family"], filtered_type="io")
    def tracers(pfilter, data):
        return data[(pfilter.filtered_type, "particle_family")] <= 0

    def _velocity_dispersion(field, data):
        # Velocity dispersion of each cell over its 3x3x3 neighbourhood,
        # relying on one layer of ghost zones around the 2x2x2 oct.
        from itertools import product

        new_field = np.zeros_like(data["gas", "velocity_x"])
        for i, j, k in product(*[range(2)] * 3):
            v = 0
            w = np.ones_like(data["gas", "density"][i : i + 3, j : j + 3, k : k + 3])
            for kk in "xyz":
                vel_block = data["gas", f"velocity_{kk}"][
                    i : i + 3, j : j + 3, k : k + 3
                ]
                vmean = np.average(vel_block, weights=w, axis=(0, 1, 2))
                v += np.average((vel_block - vmean) ** 2, weights=w, axis=(0, 1, 2))
            new_field[i + 1, j + 1, k + 1] = np.sqrt(v)

        return data.apply_units(new_field, data["gas", "velocity_x"].units)

    ds.add_field(
        ("gas", "velocity_dispersion"),
        _velocity_dispersion,
        sampling_type="cell",
        validators=[ValidateSpatial(ghost_zones=1)],
        units="cm/s",
    )
    ds.add_particle_filter("tracers")

    # Expose each mesh field, sampled at the tracer positions, as
    # ("tracers", "cell_gas_<field>").
    mesh_fields = [
        *(("gas", f"velocity_{k}") for k in "xyz"),
        ("gas", "density"),
        ("gas", "temperature"),
        *(
            ("gas", f"{species}_number_density")
            for species in ("HI", "HII", "HeI", "HeII")
        ),
        ("gas", "sound_speed"),
        ("gas", "velocity_dispersion"),
        ("gas", "cooling_total"),
        ("gas", "heating_total"),
        ("gas", "cooling_net"),
        *(("gas", f"vorticity_{k}") for k in "xyz"),
    ]
    for field in mesh_fields:
        ds.add_mesh_sampling_particle_field(field, ptype="tracers")


def extract_data(
    ds,
    fields: List[Tuple[str, str]],
):
    out_folder = (Path(ds.directory).parent / "subset").resolve()
    name = f"{str(ds)}_region.h5"
    out_filename = out_folder / name

    # Skip outputs that were already extracted with all requested fields.
    if out_filename.exists():
        found_fields = []
        with h5py.File(out_filename, "r") as f:
            for ft in f:
                for fn in f[ft]:
                    found_fields.append((ft, fn))

        # Now check all fields have been registered
        missing = [f for f in fields if f not in found_fields]
        if len(missing) > 0:
            logger.info(
                "Found data file %s, but missing %s fields", out_filename, len(missing)
            )
        else:
            logger.info("Found data file %s, all fields found", out_filename)
            return

    logger.info("Extracting data from %s", ds)
    setup_dataset(ds)

    ad = ds.all_data()
    yt.funcs.mylog.info("Computing cell indices")
    ad["tracers", "cell_index"]

    yt.funcs.mylog.info("Writing dataset into %s", out_filename)
    out_filename.parent.mkdir(parents=True, exist_ok=True)
    ad.save_as_dataset(str(out_filename), fields=fields)

    del ad, ds
    gc.collect()


def main(argv: Optional[Sequence] = None) -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--simulation",
        required=True,
        type=str,
        help="Path to the folder containing all outputs.",
    )
    parser.add_argument(
        "--output-slice",
        default=None,
        help=(
            "Slice of the outputs to consider (in the form istart:iend:istep), "
            "useful when parallelizing manually (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--bbox",
        nargs=6,
        default=[0, 0, 0, 1, 1, 1],
        type=float,
        help=(
            "Bounding box to use for the region to extract in box units "
            "(default: %(default)s)."
        ),
    )
    args = parser.parse_args(argv)

    # Build field list
    fields = [
        ("tracers", "particle_family"),
        ("tracers", "particle_mass"),
        ("tracers", "particle_identity"),
        *[("tracers", f"particle_position_{k}") for k in "xyz"],
        *[("tracers", f"particle_velocity_{k}") for k in "xyz"],
        ("tracers", "particle_position"),
        *[("tracers", f"cell_gas_velocity_{k}") for k in "xyz"],
        ("tracers", "cell_gas_density"),
        ("tracers", "cell_gas_temperature"),
        ("tracers", "cell_gas_HI_number_density"),
        ("tracers", "cell_gas_HII_number_density"),
        ("tracers", "cell_gas_HeI_number_density"),
        ("tracers", "cell_gas_HeII_number_density"),
        ("tracers", "cell_gas_sound_speed"),
        ("tracers", "cell_gas_velocity_dispersion"),
        ("tracers", "cell_gas_cooling_total"),
        ("tracers", "cell_gas_heating_total"),
        ("tracers", "cell_gas_cooling_net"),
        *[("tracers", f"cell_gas_vorticity_{k}") for k in "xyz"],
    ]

    # Loop over all datasets, skipping .tar.gz archives whose unpacked
    # counterpart already exists.
    simu = Path(args.simulation)
    outputs = [
        out
        for out in sorted(
            list(simu.glob("output_?????")) + list(simu.glob("output_?????.tar.gz"))
        )
        if not (
            out.name.endswith(".tar.gz")
            and out.with_name(out.name.replace(".tar.gz", "")).exists()
        )
    ]
    if args.output_slice is not None:
        istart, iend, istep = (
            int(i) if i else None for i in args.output_slice.split(":")
        )
        sl = slice(istart, iend, istep)
        outputs = outputs[sl]

    pbar = yt.funcs.get_pbar("Constructing trajectory information", len(outputs))
    yt.set_log_level(40)
    bbox = [args.bbox[0:3], args.bbox[3:6]]
    # Iterate from the last output to the first, distributing the work
    # across MPI ranks when running in parallel.
    for output in yt.parallel_objects(list(reversed(outputs))):
        if output.name.endswith(".tar.gz"):
            original_name = output.name.replace(".tar.gz", "")
            try:
                ds = yt.load_archive(output, original_name, mount_timeout=5, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                continue
            ds.directory = str(output.parent / original_name)
        else:
            try:
                ds = yt.load(output, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                continue

        extract_data(ds, fields)
        pbar.update(outputs.index(output))

    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main(sys.argv[1:]))
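For reference, a possible invocation is sketched below. The file name extract_tracers.py is hypothetical (the gist does not name the file), and MPI is optional: yt.enable_parallelism() only distributes the loop over outputs when the script is launched through mpirun.

# Hypothetical file name and paths; process every other output among the
# first 100, restricted to the central octant of the box.
mpirun -np 8 python extract_tracers.py \
    -i /path/to/simulation \
    --output-slice 0:100:2 \
    --bbox 0.25 0.25 0.25 0.75 0.75 0.75

Files written by save_as_dataset can be reloaded with yt. A minimal sketch, assuming one of the outputs produced subset/info_00010_region.h5 (the actual name depends on str(ds)):

import yt

# Hypothetical path; reload the extracted subset as a yt dataset.
ds = yt.load("subset/info_00010_region.h5")
ad = ds.data  # the saved data container
print(ad["tracers", "particle_identity"])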