@suyash
Last active December 27, 2022 14:51

Revisions

  1. suyash revised this gist Jan 27, 2019. 1 changed file with 2 additions and 4 deletions.
    6 changes: 2 additions & 4 deletions README.md
    @@ -58,10 +58,8 @@ Wyod nome so wat Prove fovers; inage are er, ont mangle poaves ny whien

     This project is licensed under either of

    -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
    -http://www.apache.org/licenses/LICENSE-2.0)
    -* MIT license ([LICENSE-MIT](LICENSE-MIT) or
    -http://opensource.org/licenses/MIT)
    +* Apache License, Version 2.0, (http://www.apache.org/licenses/LICENSE-2.0)
    +* MIT license (http://opensource.org/licenses/MIT)

     at your option.

  2. suyash created this gist Jan 27, 2019.
    10 changes: 10 additions & 0 deletions Cargo.toml
    @@ -0,0 +1,10 @@
    [package]
    name = "min-char-rnn-rs"
    version = "0.1.0"
    authors = ["Suyash <[email protected]>"]
    edition = "2018"

    [dependencies]
    rulinalg = "0.4.2"
    rand = "0.6.4"
    indicatif = "0.11.0"
    70 changes: 70 additions & 0 deletions README.md
    @@ -0,0 +1,70 @@
    # min-char-rnn-rs

    [@karpathy's min-char-rnn.py](https://gist.github.com/karpathy/d4dee566867f8291f086) in [Rust](src/main.rs).

    # Usage

    Same as the Python version: pass a text file as the first CLI argument.

    ```
    cargo run -- ~/.datasets/shakespeare/shakespeare_100000/shakespear.txt
    ```

    # Result

    On [shakespeare_100000](http://karpathy.github.io/2015/05/21/rnn-effectiveness/):

    ```
    iteration: 0, loss: 103.17835654878319
    sampled prediction:
    ----------
    T?unLDV.nSDVg-c.f-t.dzU-TCTElTLMT.NlSzMR!kOJrCy-nqk
    pEk,mJ
    slXWkW?YmaTGpJdtteqOJx-:WS?.Wq:TxKHuONM!o.fprDAflNZx'DDttzp Prq!KUsB
    zbL?dsQgmMUJRtA-:tjEHvWlxIcPavu-ia,fjageuTZ-lZJabSzOFnwnEwfjAnIY?XFl:fcUN
    ----------
    ...
    iteration: 6000, loss: 58.01707775597862
    sampled prediction:
    ----------
    nous jo per, be thall thout I anpe tineed ir wito.
    Ph thes
    Huve wa use muls nath ong-aI'st be me beaf.
    TUNYNTF:
    hous mutinf st arprery I dealy then.
    CORAUY TEK:
    he wist ntit Ip akt ce thee tais; fo
    ----------
    ...
    iteration: 12000, loss: 53.934933970894136
    sampled prediction:
    ----------
    ngouvy Bothet, anginnmes appat:
    She the thevisinvet, teet:
    Whe't Wit the ny ce shremall attfwmedy, pollds!
    Han thall,
    That gintk,
    Wyod nome so wat Prove fovers; inage are er, ont mangle poaves ny whien
    ----------
    ```

    # License

    This project is licensed under either of

    * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
    http://www.apache.org/licenses/LICENSE-2.0)
    * MIT license ([LICENSE-MIT](LICENSE-MIT) or
    http://opensource.org/licenses/MIT)

    at your option.

    ### Contribution

    Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
    263 changes: 263 additions & 0 deletions main.rs
    @@ -0,0 +1,263 @@
    //! [@karpathy's min-char-rnn.py](https://gist.github.com/karpathy/d4dee566867f8291f086) in Rust.
    use std::collections::{BTreeMap, BTreeSet};
    use std::env;
    use std::error::Error;
    use std::fs::File;
    use std::io::Read;
    use std::iter::FromIterator;
    use std::ops::Mul;

    use indicatif::{ProgressBar, ProgressStyle};
    use rand::distributions::{Distribution, Uniform, WeightedIndex};
    use rand::{thread_rng, Rng};
    use rulinalg::matrix::{BaseMatrix, Matrix};

    #[allow(non_snake_case, clippy::many_single_char_names)]
    fn main() -> Result<(), Box<dyn Error>> {
        let mut rng = thread_rng();

        let filename = env::args().nth(1).expect("Expected filename to be given");
        let mut f = File::open(filename)?;
        let mut text = String::new();
        f.read_to_string(&mut text)?;

        let chars = BTreeSet::from_iter(text.chars());
        let vocab_size = chars.len();
        let char_index = BTreeMap::from_iter(chars.iter().cloned().zip(0..vocab_size));
        let inverted_index = BTreeMap::from_iter(char_index.clone().into_iter().map(|(k, v)| (v, k)));

        dbg!(vocab_size);

        let text: Vec<usize> = text.chars().map(|c| char_index[&c]).collect();

        // hyperparameters

        let hidden_size = 100;
        let seq_length = 25;
        let learning_rate = 0.1;

        // model parameters

        let dist = Uniform::new(0.0, 1.0);

        // weights

        let mut W_ih = Matrix::new(
            vocab_size,
            hidden_size,
            rng.sample_iter(&dist)
                .take(vocab_size * hidden_size)
                .map(|v| v * 0.01)
                .collect::<Vec<f64>>(),
        );
        let mut W_hh = Matrix::new(
            hidden_size,
            hidden_size,
            rng.sample_iter(&dist)
                .take(hidden_size * hidden_size)
                .map(|v| v * 0.01)
                .collect::<Vec<f64>>(),
        );
        let mut W_hy = Matrix::new(
            hidden_size,
            vocab_size,
            rng.sample_iter(&dist)
                .take(hidden_size * vocab_size)
                .map(|v| v * 0.01)
                .collect::<Vec<f64>>(),
        );

        // biases

        let mut b_h = Matrix::zeros(1, hidden_size);
        let mut b_y = Matrix::zeros(1, vocab_size);

        // state

        let mut h = Matrix::zeros(1, hidden_size);

        // Memory Variables for Adagrad

        let mut m_W_ih = Matrix::zeros(vocab_size, hidden_size);
        let mut m_W_hh = Matrix::zeros(hidden_size, hidden_size);
        let mut m_W_hy = Matrix::zeros(hidden_size, vocab_size);
        let mut m_b_h = Matrix::zeros(1, hidden_size);
        let mut m_b_y = Matrix::zeros(1, vocab_size);

        // iteration variables

        let (mut n, mut p) = (0, 0);
        // start the smoothed loss at the loss of a uniform prediction over the vocabulary
        let mut smooth_loss = -(1.0 / (vocab_size as f64)).ln() * (seq_length as f64);

        let progress = ProgressBar::new_spinner();
        progress.set_style(ProgressStyle::default_spinner().template("{spinner:.green} {msg}"));
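
        // main training loop: take seq_length-character windows from the text, run the
        // forward pass, backpropagate through time, and apply an Adagrad update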

        loop {
            // at the end of the data (or on the first iteration), reset the RNN memory
            // and rewind to the start of the text
            if p + seq_length + 1 >= text.len() || n == 0 {
                h = Matrix::zeros(1, hidden_size);
                p = 0;
            }

            // NOTE: instead of defining a separate loss function, the core network is
            // implemented inline here; the pieces around it live in helper functions below

            let mut loss = 0.0;

            let mut ts = Vec::new();
            ts.push((None, h.clone(), None));
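
            // forward pass: for each timestep,
            //   h_t = tanh(x_t * W_ih + h_{t-1} * W_hh + b_h)
            //   y_t = softmax(h_t * W_hy + b_y)
            // accumulating the cross-entropy loss -ln(y_t[target]) against the next character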

            for t in 0..seq_length {
                let input = text[p + t];
                let target = text[p + t + 1];

                let x = one_hot(input, vocab_size);

                h = (&x).mul(&W_ih) + h.mul(&W_hh) + &b_h;
                h = tanh(h);

                let y = (&h).mul(&W_hy) + &b_y;
                let y = softmax(y);

                loss += -y[[0, target]].ln();

                ts.push((Some(x), h.clone(), Some(y)));
            }

            let mut d_W_ih = Matrix::zeros(vocab_size, hidden_size);
            let mut d_W_hh = Matrix::zeros(hidden_size, hidden_size);
            let mut d_W_hy = Matrix::zeros(hidden_size, vocab_size);
            let mut d_b_h = Matrix::zeros(1, hidden_size);
            let mut d_b_y = Matrix::zeros(1, vocab_size);

            // backwards gradient for current state coming in from next state
            let mut d_h_next = Matrix::zeros(1, hidden_size);
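
            // backward pass (BPTT) over the stored timesteps, most recent first:
            // dy = y - one_hot(target) is the combined softmax + cross-entropy gradient,
            // and the tanh nonlinearity contributes a (1 - h^2) factor via tanh_derivative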

            for t in (1..=seq_length).rev() {
                let target = text[p + t];

                let (x, h, y) = &ts[t];
                let mut dy = y.as_ref().unwrap().clone();
                dy[[0, target]] -= 1.0;

                d_W_hy += h.transpose().mul(&dy);
                d_b_y += &dy;

                let dh = dy.mul(W_hy.transpose()) + &d_h_next;
                let dh = dh.elemul(&tanh_derivative(h));

                let (_, prevh, _) = &ts[t - 1];

                d_W_hh += prevh.transpose().mul(&dh);
                d_b_h += &dh;

                d_W_ih += x.as_ref().unwrap().transpose().mul(&dh);

                d_h_next = dh.mul(W_hh.transpose());
            }

            // NOTE: skipping gradient clipping for now

            smooth_loss = smooth_loss * 0.999 + loss * 0.001;

            // every 100 iterations, report the smoothed loss and sample 200 characters
            // from the model, seeded with the current input character
            if n % 100 == 0 {
                progress.println(format!("iteration: {}, loss: {}", n, smooth_loss));

                let mut h_copy = h.clone();
                let mut value = text[p];
                let mut sample = String::new();
                sample.push(inverted_index[&value]);

                for _ in 0..200 {
                    let x = one_hot(value, vocab_size);

                    h_copy = tanh(x.mul(&W_ih) + (&h_copy).mul(&W_hh) + &b_h);

                    let y = softmax((&h_copy).mul(&W_hy));
                    let y = y.into_vec();

                    // draw the next character index from the output distribution
                    let dist = WeightedIndex::new(&y)?;
                    value = dist.sample(&mut rng);
                    sample.push(inverted_index[&value]);
                }

                progress.println(format!(
                    "sampled prediction:\n----------\n{}\n----------\n",
                    sample
                ));
            }

            progress.set_message(&format!("loss: {}", smooth_loss));

            // Adagrad: accumulate squared gradients, then scale each parameter's step
            // by learning_rate / sqrt(accumulated + 1e-8) (see adagrad_update below)
            m_W_ih += (&d_W_ih).elemul(&d_W_ih);
            m_W_hh += (&d_W_hh).elemul(&d_W_hh);
            m_W_hy += (&d_W_hy).elemul(&d_W_hy);
            m_b_h += (&d_b_h).elemul(&d_b_h);
            m_b_y += (&d_b_y).elemul(&d_b_y);

            W_ih -= adagrad_update(d_W_ih, &m_W_ih, learning_rate);
            W_hh -= adagrad_update(d_W_hh, &m_W_hh, learning_rate);
            W_hy -= adagrad_update(d_W_hy, &m_W_hy, learning_rate);
            b_h -= adagrad_update(d_b_h, &m_b_h, learning_rate);
            b_y -= adagrad_update(d_b_y, &m_b_y, learning_rate);

            p += seq_length;
            n += 1;
        }
    }

    /// Returns a 1 x t one-hot row vector with a 1.0 at index v.
    fn one_hot(v: usize, t: usize) -> Matrix<f64> {
        let mut data = vec![0.0; t];
        data[v] = 1.0;
        Matrix::new(1, t, data)
    }

    /// Applies tanh element-wise.
    fn tanh(mut m: Matrix<f64>) -> Matrix<f64> {
        for i in 0..m.rows() {
            for j in 0..m.cols() {
                m[[i, j]] = m[[i, j]].tanh();
            }
        }

        m
    }

    /// Applies softmax to each row.
    fn softmax(mut m: Matrix<f64>) -> Matrix<f64> {
        for i in 0..m.rows() {
            let mut s = 0.0;

            for j in 0..m.cols() {
                m[[i, j]] = m[[i, j]].exp();
                s += m[[i, j]];
            }

            for j in 0..m.cols() {
                m[[i, j]] /= s;
            }
        }

        m
    }

    /// Derivative of tanh expressed in terms of its output: 1 - tanh(x)^2.
    fn tanh_derivative(m: &Matrix<f64>) -> Matrix<f64> {
        let mut ans = Matrix::zeros(m.rows(), m.cols());

        for i in 0..m.rows() {
            for j in 0..m.cols() {
                ans[[i, j]] = 1.0 - (m[[i, j]] * m[[i, j]]);
            }
        }

        ans
    }

    /// Adagrad step: scales each gradient entry by learning_rate / sqrt(accumulated + 1e-8).
    fn adagrad_update(mut d: Matrix<f64>, m: &Matrix<f64>, learning_rate: f64) -> Matrix<f64> {
        let (r, c) = (m.rows(), m.cols());

        for i in 0..r {
            for j in 0..c {
                d[[i, j]] = (learning_rate * d[[i, j]]) / (m[[i, j]] + 1e-8).sqrt();
            }
        }

        d
    }