diff --git a/tools/scripts/dist_train.sh b/tools/scripts/dist_train.sh
new file mode 100644
index 0000000..5e8c59a
--- /dev/null
+++ b/tools/scripts/dist_train.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Launch single-node multi-GPU training via torch.distributed.
+# Usage: dist_train.sh NGPUS [extra train.py args...]
+
+set -x
+NGPUS=$1
+PY_ARGS=("${@:2}")   # array: preserves arguments that contain spaces
+
+# Pick a random TCP port in [10000, 59151] that nothing is listening on.
+# RANDOM is only 15 bits, so combine two draws for a wider spread.
+while true
+do
+    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
+    # nc -z exits non-zero when the port is free; loop until we find one.
+    if ! nc -z 127.0.0.1 "$PORT" < /dev/null &>/dev/null; then
+        break
+    fi
+done
+echo "$PORT"
+
+# NOTE(review): --rdzv_endpoint is a torchrun/elastic option; it works here only
+# because torch.distributed.launch forwards to torch.distributed.run (torch>=1.9) — confirm torch version.
+python -m torch.distributed.launch --nproc_per_node="${NGPUS}" --rdzv_endpoint="localhost:${PORT}" train.py --launcher pytorch "${PY_ARGS[@]}"