#!/bin/bash
ARGPARSE_DESCRIPTION="Trainer utility"
source "$(dirname "$0")/argparse.bash" || exit 1
argparse "$@" <<EOF || exit 1
parser.add_argument('--ngpus', default=8, type=int,
                    help='No. of gpus to use')
parser.add_argument('--training_type', type=str, choices=["m2m", "m2o", "o2m"],
                    required=True, help='Training type (many-to-many/many-to-one/one-to-many)')
parser.add_argument('--pivot_lang', type=str, default="english",
                    help='Pivot language (Applicable for many-to-one and one-to-many)')
parser.add_argument('--sampling', type=str, default="multistage", choices=["multistage", "unistage"],
                    help='Sampling type (Applicable for many-to-many)')
parser.add_argument('--minibatching', type=str, default="", choices=["fixed_src", "fixed_tgt", "ignored"],
                    help='Minibatching (Applicable for many-to-many with multistage sampling)')
parser.add_argument('--exclude_native', action='store_true',
                    default=False, help='Exclude the native-to-native filepairs during training')
parser.add_argument('--per_lang_batch_size', default=32, type=int,
                    help='Effective batch size per language')
parser.add_argument('--copy_last_checkpoint', type=str, default="",
                    help='If provided, copies the last checkpoint to this directory')
EOF
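# argparse.bash exposes each parsed option as an uppercase environment
# variable (e.g. --training_type -> $TRAINING_TYPE, --ngpus -> $NGPUS);
# store_true flags are expected to come through as "yes" when passed
# (hence the $EXCLUDE_NATIVE check further below).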
export BASE_DIR=$(realpath .)
export ROOT_DATASET_DIR="${BASE_DIR}/dataset"
export ROOT_INPUT_DIR="${BASE_DIR}/input"
export ROOT_OUTPUT_DIR="${BASE_DIR}/output"
export PREFIX="${TRAINING_TYPE}_${PIVOT_LANG}_${PER_LANG_BATCH_SIZE}"
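# Choose the run prefix and sampling flags for the training type: m2m runs
# default to multistage sampling (upsampling factors 0.5/0.75, or a single
# 0.25 factor for unistage), while m2o/o2m runs use a single 0.75 factor.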
if [[ "$TRAINING_TYPE" = "m2m" ]]; then
PREFIX="${TRAINING_TYPE}_${SAMPLING}_${PER_LANG_BATCH_SIZE}"
OPTIONAL_ARGS=(
"--multistage_upsampling_factors 0.5 0.75"
)
if [[ "$SAMPLING" = "unistage" ]]; then
OPTIONAL_ARGS=(
"--upsampling_factor 0.25"
)
fi
if [[ "$MINIBATCHING" != "" ]]; then
PREFIX="${TRAINING_TYPE}_${SAMPLING}_${MINIBATCHING}_${PER_LANG_BATCH_SIZE}"
OPTIONAL_ARGS+=(
"--minibatching $MINIBATCHING"
)
fi
else
OPTIONAL_ARGS=(
"--upsampling_factor 0.75"
)
fi
export SUFFIX="with_native"
if [[ "$EXCLUDE_NATIVE" = "yes" ]]; then
    SUFFIX="without_native"
fi
export BASENAME="${PREFIX}_${SUFFIX}"
export INPUT_DIR="${ROOT_INPUT_DIR}/${BASENAME}"
export OUTPUT_DIR="${ROOT_OUTPUT_DIR}/${BASENAME}"
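# Presumably enforced by generate_data.py: language pairs with fewer than
# this many examples are left out of the generated training data.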
export MIN_EXAMPLE_COUNT=30
conda activate "${BASE_DIR}/env" || source activate "${BASE_DIR}/env"
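# Only the globally first SLURM task generates the training data, so that
# multi-node jobs don't race on the input directory; the :-0 defaults also
# let the script run outside of SLURM.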
if [[ "${SLURM_PROCID:-0}" -eq 0 && "${SLURM_LOCALID:-0}" -eq 0 ]]; then
mkdir -p $OUTPUT_DIR
python "${BASE_DIR}/generate_data.py" \
--dataset_dir $ROOT_DATASET_DIR \
--output_dir $INPUT_DIR \
--training_type $TRAINING_TYPE \
--pivot_lang $PIVOT_LANG \
--exclude_native $EXCLUDE_NATIVE \
--min_example_count $MIN_EXAMPLE_COUNT
fi
# For OzSTAR only; if set to true, the model must already be cached locally.
export LINK_CACHE_ONLY=false
# training settings
export max_steps=25000
export save_steps=25000
export logging_steps=100
# validation settings
export evaluation_strategy="no"
# model settings
export model_name="google/mt5-base"
# optimization settings
export learning_rate=1
export warmup_steps=5000
export gradient_accumulation_steps=4
export weight_decay=0.01
export lr_scheduler_type="transformer"
export label_smoothing_factor=0.1
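# A learning rate of 1 is presumably a relative/scaled value interpreted by
# the custom "transformer" scheduler together with --adafactor in the launch
# command below, rather than an absolute step size.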
# misc. settings
export seed=1234
# input / output settings
export input_dir=$INPUT_DIR
export output_dir=$OUTPUT_DIR
# batch / sequence sizes
export PER_DEVICE_TRAIN_BATCH_SIZE=8
export MAX_SOURCE_LENGTH=512
export MAX_TARGET_LENGTH=84
# cross lingual settings
export per_lang_batch_size=$PER_LANG_BATCH_SIZE
# logging settings
export WANDB_PROJECT="Crossum"
export WANDB_WATCH=false
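# Launch one process per GPU on each node. Rendezvous settings come from
# SLURM where available; $PARENT and $MPORT are presumably exported by the
# submitting job script and default to a single-node localhost setup.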
python -m torch.distributed.launch \
    --nproc_per_node=${NPROC_PER_NODE:-$NGPUS} \
    --nnodes=${SLURM_JOB_NUM_NODES:-1} \
    --node_rank=${SLURM_PROCID:-0} \
    --master_addr="${PARENT:-127.0.0.1}" --master_port="${MPORT:-29500}" "${BASE_DIR}/pipeline.py" \
    --model_name_or_path $model_name \
    --data_dir $INPUT_DIR --output_dir $OUTPUT_DIR \
    --learning_rate=$learning_rate --warmup_steps $warmup_steps --gradient_accumulation_steps $gradient_accumulation_steps \
    --weight_decay $weight_decay --lr_scheduler_type $lr_scheduler_type --adafactor --label_smoothing_factor $label_smoothing_factor \
    --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE --logging_steps $logging_steps \
    --max_source_length $MAX_SOURCE_LENGTH --max_target_length $MAX_TARGET_LENGTH \
    --per_lang_batch_size $per_lang_batch_size \
    --seed $seed --overwrite_output_dir \
    --max_steps $max_steps --save_steps $save_steps \
    --evaluation_strategy $evaluation_strategy \
    --logging_first_step \
    --cache_dir "${BASE_DIR}/cache_dir" \
    --run_name $BASENAME \
    --use_langid \
    --langid_map_path "${BASE_DIR}/debug/extra_tokens_langid_map.json" \
    --reinitialize_langid_embeddings "bos" \
    --do_train \
    $(echo -n ${OPTIONAL_ARGS[@]}) |& tee "${OUTPUT_DIR}/run.log"
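# Since save_steps equals max_steps, the run produces a single checkpoint,
# checkpoint-${max_steps}; optionally copy it out under the run's basename.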
if [[ "$COPY_LAST_CHECKPOINT" != "" ]]; then
mkdir -p "$COPY_LAST_CHECKPOINT"
cp -r "${OUTPUT_DIR}/checkpoint-${max_steps}" "${COPY_LAST_CHECKPOINT}/${BASENAME}"
fi