Skip to content

Instantly share code, notes, and snippets.

@cphyc
Last active September 6, 2023 08:32
Show Gist options
  • Save cphyc/e42f815ddd78600366e03ff6024b4f9c to your computer and use it in GitHub Desktop.

Revisions

  1. cphyc revised this gist Sep 6, 2023. 1 changed file with 1 addition and 2 deletions.
    3 changes: 1 addition & 2 deletions extract_tracer_data.py
    Original file line number Diff line number Diff line change
    @@ -6,10 +6,9 @@
    import h5py
    import numpy as np
    import yt
    from astrophysics_toolset.utilities.logging import logger
    from yt import mylog as logger
    from yt.fields.derived_field import ValidateSpatial


    yt.enable_parallelism()
    logger.setLevel(10)

  2. cphyc created this gist Sep 6, 2023.
    209 changes: 209 additions & 0 deletions extract_tracer_data.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,209 @@
    import argparse
    import gc
    from pathlib import Path
    from typing import List, Optional, Sequence, Tuple

    import h5py
    import numpy as np
    import yt
    from astrophysics_toolset.utilities.logging import logger
    from yt.fields.derived_field import ValidateSpatial


    yt.enable_parallelism()
    logger.setLevel(10)


def setup_dataset(ds):
    """Register tracer-particle machinery and derived fields on *ds*.

    Adds, in order:
      * a ``tracers`` particle filter selecting particles with
        ``particle_family <= 0``;
      * a ``("gas", "velocity_dispersion")`` derived field requiring one
        layer of ghost zones;
      * mesh-sampling particle fields (``cell_gas_*``) for the tracers.
    """

    @yt.particle_filter(requires=["particle_family"], filtered_type="io")
    def tracers(pfilter, data):
        # Tracer particles are flagged with a non-positive family value.
        return data[(pfilter.filtered_type, "particle_family")] <= 0

    def _velocity_dispersion(field, data):
        from itertools import product

        # Velocity dispersion over a 3x3x3 neighbourhood, summed across the
        # three velocity components; relies on ghost_zones=1 (see add_field
        # below) so the stencil can reach past the block edges.
        # NOTE(review): i, j, k only range over 0..1, so only offsets 1..2 of
        # new_field are written — confirm this matches the intended coverage.
        new_field = np.zeros_like(data["gas", "velocity_x"])
        for i, j, k in product(*[range(2)] * 3):
            v = 0
            # Uniform weights over the 3x3x3 stencil.
            w = np.ones_like(data["gas", "density"][i : i + 3, j : j + 3, k : k + 3])
            for kk in "xyz":
                vel_block = data["gas", f"velocity_{kk}"][
                    i : i + 3, j : j + 3, k : k + 3
                ]

                vmean = np.average(vel_block, weights=w, axis=(0, 1, 2))
                # Accumulate the variance of each velocity component.
                v += np.average((vel_block - vmean) ** 2, weights=w, axis=(0, 1, 2))

            new_field[i + 1, j + 1, k + 1] = np.sqrt(v)

        return data.apply_units(new_field, data["gas", "velocity_x"].units)

    ds.add_field(
        ("gas", "velocity_dispersion"),
        _velocity_dispersion,
        sampling_type="cell",
        validators=[ValidateSpatial(ghost_zones=1)],
        units="cm/s",
    )

    ds.add_particle_filter("tracers")

    # Gas fields to be sampled from the mesh at each tracer position.
    mesh_fields = [
        *(("gas", f"velocity_{k}") for k in "xyz"),
        ("gas", "density"),
        ("gas", "temperature"),
        *(
            ("gas", f"{species}_number_density")
            for species in ("HI", "HII", "HeI", "HeII")
        ),
        ("gas", "sound_speed"),
        ("gas", "velocity_dispersion"),
        ("gas", "cooling_total"),
        ("gas", "heating_total"),
        ("gas", "cooling_net"),
        *(("gas", f"vorticity_{k}") for k in "xyz"),
    ]
    for field in mesh_fields:
        ds.add_mesh_sampling_particle_field(field, ptype="tracers")


    def extract_data(
    ds,
    fields: List[Tuple[str, str]],
    ):
    out_folder = (Path(ds.directory).parent / "subset").resolve()
    name = f"{str(ds)}_region.h5"
    out_filename = out_folder / name

    found_fields = []
    if out_filename.exists():
    found_fields = []
    with h5py.File(out_filename, "r") as f:
    for ft in f:
    for fn in f[ft]:
    found_fields.append((ft, fn))

    # Now check all fields have been registered
    missing = [f for f in fields if f not in found_fields]

    if len(missing) > 0:
    logger.info(
    "Found data file %s, but missing %s fields", out_filename, len(missing)
    )
    else:
    logger.info("Found data file %s, all fields found", out_filename)
    return

    logger.info("Extracting data from %s", ds)
    setup_dataset(ds)

    ad = ds.all_data()

    yt.funcs.mylog.info("Computing cell indices")
    ad["tracers", "cell_index"]

    yt.funcs.mylog.info("Writing dataset into %s", out_filename)
    out_filename.parent.mkdir(parents=True, exist_ok=True)
    ad.save_as_dataset(str(out_filename), fields=fields)
    del ad, ds

    gc.collect()


def main(argv: Optional[Sequence] = None) -> int:
    """Extract tracer data from every output of a simulation folder.

    Parameters
    ----------
    argv : sequence of str, optional
        Command-line arguments (defaults to ``sys.argv[1:]`` via argparse).

    Returns
    -------
    int
        0 on success, suitable for ``sys.exit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--simulation",
        required=True,
        type=str,
        help="Path to the folder containing all outputs.",
    )
    parser.add_argument(
        "--output-slice",
        default=None,
        help=(
            "Slices of the output to consider (in the form istart:iend:istep), "
            "useful when parallelizing manually (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--bbox",
        nargs=6,
        default=[0, 0, 0, 1, 1, 1],
        type=float,
        help=(
            "Bounding box to use for the region to extract in box units "
            "(default: %(default)s)."
        ),
    )
    args = parser.parse_args(argv)

    # Build field list
    fields = [
        ("tracers", "particle_family"),
        ("tracers", "particle_mass"),
        ("tracers", "particle_identity"),
        *[("tracers", f"particle_position_{k}") for k in "xyz"],
        *[("tracers", f"particle_velocity_{k}") for k in "xyz"],
        ("tracers", "particle_position"),
        *[("tracers", f"cell_gas_velocity_{k}") for k in "xyz"],
        ("tracers", "cell_gas_density"),
        ("tracers", "cell_gas_temperature"),
        ("tracers", "cell_gas_HI_number_density"),
        ("tracers", "cell_gas_HII_number_density"),
        ("tracers", "cell_gas_HeI_number_density"),
        ("tracers", "cell_gas_HeII_number_density"),
        ("tracers", "cell_gas_sound_speed"),
        ("tracers", "cell_gas_velocity_dispersion"),
        ("tracers", "cell_gas_cooling_total"),
        ("tracers", "cell_gas_heating_total"),
        ("tracers", "cell_gas_cooling_net"),
        *[("tracers", f"cell_gas_vorticity_{k}") for k in "xyz"],
    ]

    # Loop over all datasets
    # Keep plain output folders and tarballs, but skip a tarball whenever the
    # already-extracted folder with the same name exists next to it.
    simu = Path(args.simulation)
    outputs = [
        out
        for out in sorted(
            list(simu.glob("output_?????")) + list(simu.glob("output_?????.tar.gz"))
        )
        if not (
            out.name.endswith(".tar.gz")
            and out.with_name(out.name.replace(".tar.gz", "")).exists()
        )
    ]
    if args.output_slice is not None:
        # NOTE(review): assumes exactly "istart:iend:istep" (two colons);
        # a value with fewer colons would raise here — confirm intended usage.
        istart, iend, istep = (
            int(i) if i else None for i in args.output_slice.split(":")
        )
        sl = slice(istart, iend, istep)
        outputs = outputs[sl]
    pbar = yt.funcs.get_pbar("Constructing trajectory information", len(outputs))
    # Level 40 == ERROR: quieten yt while iterating over many datasets.
    yt.set_log_level(40)
    bbox = [args.bbox[0:3], args.bbox[3:6]]
    # Most recent outputs first; yt.parallel_objects spreads them over ranks.
    for output in yt.parallel_objects(list(reversed(outputs))):
        if output.name.endswith(".tar.gz"):
            original_name = output.name.replace(".tar.gz", "")
            try:
                ds = yt.load_archive(output, original_name, mount_timeout=5, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                # Not a dataset yt can read; skip it.
                continue
            ds.directory = str(output.parent / original_name)
        else:
            try:
                ds = yt.load(output, bbox=bbox)
            except yt.utilities.exceptions.YTUnidentifiedDataType:
                continue

        extract_data(ds, fields)
        pbar.update(outputs.index(output))

    return 0


    if __name__ == "__main__":
    import sys

    sys.exit(main(sys.argv[1:]))