
@btrude
Last active December 10, 2020 18:03

Revisions

  1. btrude revised this gist Sep 21, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion sample_extended.py
    @@ -121,7 +121,7 @@ def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
         levels_map = { 1: 2, 2: 1, 3: 0 } # { hps.levels: level }
         if levels_map[hps.levels] == level and level:
             print(f'level {level} done and explicit exit invoked')
-            exit()
+            break
         return zs

    # Generate ancestral samples given a list of artists and genres
  2. btrude revised this gist Sep 12, 2020. No changes.
  3. btrude created this gist Sep 12, 2020.
    371 changes: 371 additions & 0 deletions sample_extended.py
    @@ -0,0 +1,371 @@
    import os
    import torch as t
    import jukebox.utils.dist_adapter as dist

    from jukebox.hparams import Hyperparams
    from jukebox.data.labels import EmptyLabeller
    from jukebox.utils.torch_utils import empty_cache
    from jukebox.utils.audio_utils import save_wav, load_audio
    from jukebox.make_models import make_model
    from jukebox.align import get_alignment
    from jukebox.save_html import save_html
    from jukebox.utils.sample_utils import split_batch, get_starts
    from jukebox.utils.dist_utils import print_once
    import fire

# Sample a partial window of length < n_ctx with tokens_to_sample new tokens on level=level
def sample_partial_window(zs, labels, sampling_kwargs, level, prior, tokens_to_sample, hps):
    z = zs[level]
    n_ctx = prior.n_ctx
    current_tokens = z.shape[1]
    if current_tokens < n_ctx - tokens_to_sample:
        sampling_kwargs['sample_tokens'] = current_tokens + tokens_to_sample
        start = 0
    else:
        sampling_kwargs['sample_tokens'] = n_ctx
        start = current_tokens - n_ctx + tokens_to_sample

    return sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
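
# Worked example (illustrative numbers, assuming n_ctx=8192): with 1024 tokens
# already sampled and tokens_to_sample=2048, the first branch applies
# (1024 < 8192 - 2048), giving sample_tokens=3072, start=0; with 7000 tokens
# already sampled, the else branch slides the window so the last 2048 of its
# 8192 tokens are new: sample_tokens=8192, start = 7000 - 8192 + 2048 = 860.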

# Sample a single window of length=n_ctx at position=start on level=level
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # get z already sampled at current level
    z = zs[level][:,start:end]

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update z with new sample
    z_new = z[:,-new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
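
# Note: the batch above is split into chunks of at most max_batch_size so each
# prior.sample call fits in GPU memory, then re-concatenated along the batch
# dimension; only the newly generated tail (z_new) is appended to zs[level].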

# Sample total_length tokens at level=level with hop_length=hop_length
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps):
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        for start in get_starts(total_length, prior.n_ctx, hop_length):
            zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior, total_length, hps)
    return zs
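
# Note: hop_length = hop_fraction[level] * n_ctx (set in _sample below) controls
# how far each sampling window advances; consecutive windows overlap by
# n_ctx - hop_length tokens, so (illustrative) hop_fraction=0.5 with n_ctx=8192
# re-conditions each new window on the previous 4096 tokens.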

# Sample multiple levels
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length//prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)

        levels_map = { 1: 2, 2: 1, 3: 0 } # { hps.levels: level }
        if levels_map[hps.levels] == level and level:
            print(f'level {level} done and explicit exit invoked')
            exit()
    return zs
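
# Note: the levels_map early exit above (not present in stock jukebox sample.py)
# stops once the level implied by hps.levels has finished, e.g. hps.levels=2
# exits after level 1 is done, skipping the level-0 upsampling pass; level 0
# never triggers the exit because of the `and level` guard.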

# Generate ancestral samples given a list of artists and genres
def ancestral_sample(labels, sampling_kwargs, priors, hps):
    sample_levels = list(range(len(priors)))
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs

# Continue ancestral sampling from previously saved codes
def continue_sample(zs, labels, sampling_kwargs, priors, hps):
    sample_levels = list(range(len(priors)))
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs

# Upsample given already generated upper-level codes
def upsample(zs, labels, sampling_kwargs, priors, hps):
    sample_levels = list(range(len(priors) - 1))
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs

# Prompt the model with raw audio input (dimension: NTC) and generate continuations
def primed_sample(x, labels, sampling_kwargs, priors, hps):
    sample_levels = list(range(len(priors)))
    zs = priors[-1].encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs

# Load `duration` seconds of the given audio files to use as prompts
def load_prompts(audio_files, duration, hps):
    xs = []
    for audio_file in audio_files:
        x = load_audio(audio_file, sr=hps.sr, duration=duration, offset=0.0, mono=True)
        x = x.T # CT -> TC
        xs.append(x)
    while len(xs) < hps.n_samples:
        xs.extend(xs)
    xs = xs[:hps.n_samples]
    x = t.stack([t.from_numpy(x) for x in xs])
    x = x.to('cuda', non_blocking=True)
    return x
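
# Note: load_audio returns audio as (channels, time); the transpose gives the
# (time, channels) layout expected by primed_sample, and the prompt list is
# tiled until it covers hps.n_samples, so passing fewer prompt files than
# samples just repeats prompts across the batch.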

def match_n_samples(zs, n_samples):
    for i, z in enumerate(zs):
        if z.shape[0] != n_samples:
            print(f"Expected bs = {n_samples}, got {zs[i].shape[0]}")
            if zs[i].shape[0] > n_samples:
                print(f"Truncating samples to match expected shape for level {i}")
                zs[i] = zs[i][:n_samples]
            else:
                print(f'Extending current samples to match expected for level {i}')
                extension = [z for z in zs[i]]
                while len(extension) <= n_samples:
                    extension.extend(extension)
                extension = extension[:n_samples]
                zs[i] = t.stack(extension)
    return zs
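
# Note: match_n_samples reconciles codes saved with a different batch size:
# extra samples are dropped, and too few are tiled (repeated) until the batch
# reaches n_samples, mirroring the prompt tiling in load_prompts.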

# Load codes from previous sampling run
def load_codes(codes_file, duration, priors, hps):
    data = t.load(codes_file, map_location='cpu')
    zs = [z.cuda() for z in data['zs']]
    n_samples = hps.n_samples

    if hps.get('pref_codes'):
        codes = [c for c in hps.pref_codes]
        while len(codes) < n_samples:
            codes.extend(codes)
        codes = codes[:n_samples]

        codes = [zs[-1][c] for c in codes]
        zs[-1] = t.stack(codes)

    zs = match_n_samples(zs, n_samples)

    assert zs[-1].shape[0] == hps.n_samples, f"Expected bs = {hps.n_samples}, got {zs[-1].shape[0]}"
    del data
    if duration is not None:
        # Cut off codes to match duration
        top_raw_to_tokens = priors[-1].raw_to_tokens
        assert duration % top_raw_to_tokens == 0, f"Cut-off duration {duration} not an exact multiple of top_raw_to_tokens"
        assert duration//top_raw_to_tokens <= zs[-1].shape[1], f"Cut-off tokens {duration//priors[-1].raw_to_tokens} longer than tokens {zs[-1].shape[1]} in saved codes"
        zs = [z[:,:duration//prior.raw_to_tokens] for z, prior in zip(zs, priors)]
    return zs
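
# Hypothetical example: for a codes file saved from a 4-sample run,
#   --pref_codes=2,3 --n_samples=4
# keeps only saved top-level samples 2 and 3 and tiles them to the batch size,
# so further upsampling/continuation proceeds from the preferred takes only.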

# Generate and save samples, alignment, and webpage for visualization.
def save_samples(model, device, hps, sample_hps):
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
    vqvae, priors = make_model(model, device, hps)

    assert hps.sample_length//priors[-2].raw_to_tokens >= priors[-2].n_ctx, f"Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length"

    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0

    # Set artist/genre/lyrics for your samples here!
    # We used different label sets in our models, but you can write the human-friendly names here and we'll map them under the hood for each model.
    # For the 5b/5b_lyrics model and the upsamplers, the labeller will look up artist and genres in the v2 set (after lowercasing, removing non-alphanumerics and collapsing whitespace to _).
    # For the 1b_lyrics top level, the labeller will look up artist and genres in the v3 set (after lowercasing).

    l2_meta_artist = hps.get('l2_meta_artist', 'unknown')
    l2_meta_genre = hps.get('l2_meta_genre', 'unknown')
    l2_meta_lyrics = hps.get('l2_meta_lyrics', '')
    metas = [
        {
            "artist": l2_meta_artist,
            "genre": l2_meta_genre,
            "lyrics": l2_meta_lyrics,
            "total_length": total_length,
            "offset": offset,
        },
    ]
    while len(metas) < hps.n_samples:
        metas.extend(metas)
    metas = metas[:hps.n_samples]

    l1_meta_artist = hps.get('l1_meta_artist', 'unknown')
    l1_meta_genre = hps.get('l1_meta_genre', 'unknown')
    l1_meta_lyrics = hps.get('l1_meta_lyrics', '')
    metas_u1 = [
        {
            "artist": l1_meta_artist,
            "genre": l1_meta_genre,
            "lyrics": l1_meta_lyrics,
            "total_length": total_length,
            "offset": offset,
        }
    ]
    while len(metas_u1) < hps.n_samples:
        metas_u1.extend(metas_u1)
    metas_u1 = metas_u1[:hps.n_samples]

    l0_meta_artist = hps.get('l0_meta_artist', 'unknown')
    l0_meta_genre = hps.get('l0_meta_genre', 'unknown')
    l0_meta_lyrics = hps.get('l0_meta_lyrics', '')
    metas_u0 = [
        {
            "artist": l0_meta_artist,
            "genre": l0_meta_genre,
            "lyrics": l0_meta_lyrics,
            "total_length": total_length,
            "offset": offset,
        }
    ]
    while len(metas_u0) < hps.n_samples:
        metas_u0.extend(metas_u0)
    metas_u0 = metas_u0[:hps.n_samples]

    # cond_level := {3: 2, 2: 1, 1: 0}
    labels = []
    for prior in priors:
        clevel = prior.__dict__['cond_level']
        if clevel == 3:
            labels.append(prior.labeller.get_batch_labels(metas, 'cuda'))
        elif clevel == 2:
            labels.append(prior.labeller.get_batch_labels(metas_u1, 'cuda'))
        elif clevel == 1:
            labels.append(prior.labeller.get_batch_labels(metas_u0, 'cuda'))

    for label in labels:
        assert label['y'].shape[0] == hps.n_samples
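
    # Note: a prior's cond_level appears to sit one above its own level, so
    # cond_level 3 picks out the top-level (level 2) prior, which gets the
    # l2_meta_* labels, while cond_level 2 and 1 are the upsamplers, which get
    # the l1_meta_* and l0_meta_* labels. This is what allows different
    # artist/genre/lyrics metadata per level.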

    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size = 32
        max_batch_size = 16
    else:
        chunk_size = 16
        max_batch_size = 3
    temperature = hps.get('temperature', 0.99)
    l1_temperature = hps.get('l1_temperature', 1)
    l0_temperature = hps.get('l0_temperature', 1)
    sampling_kwargs = [dict(temp=l0_temperature, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=l1_temperature, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=temperature, fp16=True, chunk_size=chunk_size, max_batch_size=max_batch_size)]
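
    # Note: sampling_kwargs is indexed by level (0 and 1 are the upsamplers,
    # 2 the top level). Smaller chunk_size/max_batch_size trade sampling speed
    # for lower GPU memory use, and the separate l1/l0 temperatures (an
    # extension over stock sample.py) let the upsamplers be tuned independently
    # of the top level.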

    if sample_hps.mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode in ['continue', 'upsample', 'truncate']:
        assert sample_hps.codes_file is not None
        top_raw_to_tokens = priors[-1].raw_to_tokens
        if sample_hps.prompt_length_in_seconds is not None:
            duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        else:
            duration = None
        zs = load_codes(sample_hps.codes_file, duration, priors, hps)
        if sample_hps.mode == 'continue':
            continue_sample(zs, labels, sampling_kwargs, priors, hps)
        elif sample_hps.mode == 'upsample':
            upsample(zs, labels, sampling_kwargs, priors, hps)
        elif sample_hps.mode == 'truncate':
            truncate(zs, labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode == 'primed':
        assert sample_hps.audio_file is not None
        assert sample_hps.prompt_length_in_seconds is not None
        audio_files = sample_hps.audio_file.split(',')
        top_raw_to_tokens = priors[-1].raw_to_tokens
        duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')

def truncate(zs, labels, sampling_kwargs, priors, hps):
    alignments = None
    levels_map = { 1: 2, 2: 1, 3: 0 } # { hps.levels: level }
    level = levels_map[hps.levels]
    prior = priors[level]
    prior.cuda()
    empty_cache()

    assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
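    # The divisor below appears to approximate seconds of audio per context
    # window: with raw_to_tokens=128 at sr=44100, n_ctx=8192 spans ~23.8s (~24)
    # and n_ctx=6144 spans ~17.8s (~17), so n_ctx / divisor is roughly tokens
    # per second and truncate_to_token is ~sample_length_in_seconds of tokens.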
    divisor = 24 if prior.n_ctx == 8192 else 17
    truncate_to_token = int((prior.n_ctx / divisor) * hps.sample_length_in_seconds)
    zs[level] = t.stack([z[:truncate_to_token] for z in zs[level]])

    prior.cpu()
    empty_cache()

    # Decode sample
    x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

    if dist.get_world_size() > 1:
        logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
    else:
        logdir = f"{hps.name}/level_{level}"
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
    save_wav(logdir, x, hps.sr)
    if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
        alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
    return zs

def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)

if __name__ == '__main__':
    fire.Fire(run)
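
# Illustrative invocations (hypothetical names/paths; flag values are examples):
#   python sample_extended.py 1b_lyrics --name=my_sample --mode=ancestral \
#     --levels=3 --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 \
#     --sr=44100 --n_samples=4 --hop_fraction=0.5,0.5,0.125
#   python sample_extended.py 1b_lyrics --name=my_sample --mode=continue \
#     --codes_file=my_sample/level_2/data.pth.tar --sample_length_in_seconds=40 \
#     --total_sample_length_in_seconds=180 --sr=44100 --n_samples=4 \
#     --hop_fraction=0.5,0.5,0.125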