Skip to content

Commit

Permalink
feat: make it easier to download the datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
dimakis committed Nov 14, 2023
1 parent 7124cd7 commit 444ce4b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions demo-notebooks/guided-demos/2_basic_jobs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@
"id": "83d77b74",
"metadata": {},
"source": [
"First, let's begin by submitting to Ray, training a basic NN on the MNIST dataset:"
"First, let's begin by submitting to Ray, training a basic NN on the MNIST dataset:\n",
"\n",
"NOTE: To test this demo in an air-gapped/ disconnected environment alter the training script to use a local dataset.\n",
"First we must download the MNIST dataset. We've included a helper script to do this for you. \n",
"\n",
"You can run the python script (`python download_mnist_datasets.py`) directly and then place the dataset in the same directory as this notebook. \n",
"The path to the dataset would be: `..guided-demos/MNIST/raw/` "
]
},
{
Expand All @@ -129,6 +135,7 @@
"jobdef = DDPJobDefinition(\n",
" name=\"mnisttest\",\n",
" script=\"mnist.py\",\n",
" # script=\"mnist_disconnected.py\", # training script for disconnected environment\n",
" scheduler_args={\"requirements\": \"requirements.txt\"}\n",
")\n",
"job = jobdef.submit(cluster)"
Expand Down Expand Up @@ -302,7 +309,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.18"
},
"vscode": {
"interpreter": {
Expand Down
2 changes: 1 addition & 1 deletion demo-notebooks/guided-demos/mnist_disconnected.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# %%

local_minst_path = os.path.dirname(os.path.abspath(__file__) + "/MNIST/raw")
local_minst_path = os.path.dirname(os.path.abspath(__file__))

print("prior to running the trainer")
print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
Expand Down

0 comments on commit 444ce4b

Please sign in to comment.