Created
June 6, 2020 07:35
-
-
Save jerinphilip/aeefdee60a1aaf1de38c3798ead70cc1 to your computer and use it in GitHub Desktop.
Revisions
-
jerinphilip created this gist
Jun 6, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,126 @@ #!/bin/bash #SBATCH --partition long #SBATCH --account jerin #SBATCH --nodes 1 #SBATCH --ntasks 1 #SBATCH --gres gpu:4 #SBATCH --cpus-per-task 40 #SBATCH --mem-per-cpu 2G #SBATCH --time UNLIMITED #SBATCH --signal=B:HUP@900 #SBATCH --output logs/aleatoric-train-%j.out module load python/3.7.0 module load pytorch/1.1.0 IMPORTS=( ufal-en-tam.tar filtered-iitb.tar ilci.tar national-newscrawl.tar wat-ilmpc.tar cricket-all.tar ) LOCAL_ROOT="/ssd_scratch/cvit/$USER/$TAG" REMOTE_ROOT="ada:/share1/dataset/text" TAG=$1 CONFIG=$2 mkdir -p $LOCAL_ROOT/{data,checkpoints} DATA=$LOCAL_ROOT/data CHECKPOINTS=$LOCAL_ROOT/checkpoints REMOTE_DIR="ada:/share3/$USER/$TAG/checkpoints" function copy { for IMPORT in ${IMPORTS[@]}; do rsync --progress $REMOTE_ROOT/$IMPORT $DATA/ tar_args="$DATA/$IMPORT -C $DATA/" tar -df $tar_args 2> /dev/null || tar -kxvf $tar_args done rsync -rv --progress $REMOTE_DIR/checkpoint_last.pt $CHECKPOINTS/ if [ $? -ne 0 ]; then echo "Copying base model"; rsync -rv --progress ada:/share1/jerin/ilmulti-checkpoints/mm-all-aleatoric-base.pt $CHECKPOINTS/checkpoint_last.pt fi } export ILMULTI_CORPUS_ROOT=$DATA # MODEL="transformer_vaswani_wmt_en_fr_big" CONSTRAIN=1024 MODEL="utransformer" set -x function train { python3 train.py \ --task shared-multilingual-translation \ --share-all-embeddings \ --num-workers 0 \ --arch $MODEL \ --max-tokens $CONSTRAIN --lr 1e-4 --min-lr 1e-9 \ --optimizer sgd \ --save-dir $CHECKPOINTS \ --log-format simple --log-interval 200 \ --criterion uc_loss \ --keep-interval-updates 5 \ --save-interval-updates 1000 \ --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \ --ddp-backend no_c10d \ --update-freq 4 \ --reset-optimizer \ --max-source-positions $CONSTRAIN \ --max-target-positions $CONSTRAIN \ $CONFIG & wait } function preprocess { python3 preprocess_cvit.py \ $CONFIG #--rebuild \ } function _test { OUTFILE=ufal-gen-tbig.out python3 generate.py config.yaml \ --task shared-multilingual-translation \ --path $CHECKPOINTS/checkpoint_last.pt > $OUTFILE cat $OUTFILE \ | grep "^H" | sed 's/^H-//g' | sort -n | cut -f 3 \ | sed 's/ //g' | sed 's/▁/ /g' | sed 's/^ //g' \ > ufal-test.hyp cat $OUTFILE \ | grep "^T" | sed 's/^T-//g' | sort -n | cut -f 2 \ | sed 's/ //g' | sed 's/▁/ /g' | sed 's/^ //g' \ > ufal-test.ref split -d -l 2000 ufal-test.hyp hyp.ufal. split -d -l 2000 ufal-test.ref ref.ufal. # perl multi-bleu.perl ref.ufal.00 < hyp.ufal.00 # perl multi-bleu.perl ref.ufal.01 < hyp.ufal.01 python3 -m indicnlp.contrib.wat.evaluate \ --reference ref.ufal.00 --hypothesis hyp.ufal.00 python3 -m indicnlp.contrib.wat.evaluate \ --reference ref.ufal.01 --hypothesis hyp.ufal.01 } function _export { ssh $USER@ada "mkdir -p $REMOTE_DIR/" rsync -rz $CHECKPOINTS/checkpoint_{best,last}.pt $REMOTE_DIR/ } trap "_export" SIGHUP copy preprocess train _export # ARG=$1 # eval "$1"