Training Script stability 3B and 7B #72
Comments
Are you able to train 7B using dual RTX 3090s? Do you think you could set up a notebook on Colab? Thank you! |
Yes I can; the provided script does it all. If you wish to train on Colab, then a single GPU with 24 GB of memory would be required. |
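For what it's worth, here is a minimal sketch of the single-GPU variant, assuming the `train_gptNX_v2.py` script from this repo (the file name and launcher flag are taken from the setup notes in the script below):

```python
# Sketch only: for a single 24 GB GPU (e.g. Colab), the script's world
# size drops to one and the DeepSpeed launcher is told to use one GPU.
import os

os.environ["WORLD_SIZE"] = "1"   # the script as posted sets "2" for two GPUs

# then launch from the shell with:
#   deepspeed train_gptNX_v2.py --num_gpus=1
```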
I'm using V100s on AWS (2 GPUs, 8 CPUs, 96 GB RAM), and it's failing with OOM. Any idea why? |
Send me the output you are getting on the Python console.
|
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
```
Traceback (most recent call last):
  File "train_gptNX_3B_v3.py", line 163, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 2751, in training_step
    loss = self.deepspeed.backward(loss)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1851, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1884, in backward
    buf_0 = torch.empty(int(self.reduce_bucket_size),
RuntimeError: CUDA out of memory. Tried to allocate 764.00 MiB (GPU 0; 15.77 GiB total capacity; 15.23 GiB already allocated; 54.88 MiB free; 15.25 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  0%|          | 0/1248 [00:00<?, ?it/s]
[2023-05-20 08:50:35,224] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 1348
[2023-05-20 08:50:35,718] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 1349
[2023-05-20 08:50:35,719] [ERROR] [launch.py:434:sigkill_handler] ['/opt/conda/bin/python3.8', '-u', 'train_gptNX_3B_v3.py', '--local_rank=1', '--num_gpus=2'] exits with return code = 1
```
@aamir-gmail what am I doing wrong? I've got 2 GPUs, 8 CPUs, 96 GB RAM. |
Did you follow the instructions? I can see that you are using Python 3.8 (instead of 3.9). Did you build DeepSpeed from source with CPU Adam optimiser support? Reduce your batch size to 1 and try from there. Follow the instructions as mentioned and then come back to me.
|
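Beyond dropping the batch size to 1, the OOM message itself points at the allocator's `max_split_size_mb` knob. A hedged sketch of that tweak follows; the value 128 is illustrative, not something recommended in this thread:

```python
# Sketch only: set the allocator hint before torch initialises CUDA,
# e.g. at the very top of the training script, and keep the per-device
# micro-batch at 1 as advised above.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"  # illustrative value

# in TrainingArguments:
#   per_device_train_batch_size=1
```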
```python
# Developed by Aamir Mirza
#
# Setup:
#   1. Create a conda virtual environment with Python 3.9.
#   2. Install PyTorch 1.13.1 (not 2.0):
#        conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
#   3. Install the latest transformers:
#        conda install -c conda-forge transformers
#   4. Install DeepSpeed from GitHub (not pip install), built with CPU Adam optimiser support:
#        git clone https://github.com/microsoft/DeepSpeed
#        cd DeepSpeed
#        DS_BUILD_CPU_ADAM=1 pip install .
#   5. Install accelerate via pip, plus Ninja and mpi4py:
#        pip install accelerate
#        pip install Ninja
#        conda install -c conda-forge mpi4py
#
# Train via the command line, for example (in my case 2x 3090 24GB):
#   deepspeed train_gptNX_v2.py --num_gpus=2

from transformers import (GPTNeoXForCausalLM, GPTNeoXTokenizerFast, TextDataset,
                          DefaultDataCollator, DataCollatorForLanguageModeling,
                          DataCollatorWithPadding)
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import os

os.environ['OMPI_MCA_opal_cuda_support'] = 'true'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# If you have a single GPU then change this to "1".
os.environ["WORLD_SIZE"] = "2"

# Change this to your requirement, for example 4096 (max).
MAX_LEN = 1024

# Example DeepSpeed section (the full config lives in the JSON file referenced below).
stage2_config = """{
    "bf16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    }
}"""


class CustomTrainer(Trainer):
    def compute_loss(self, model_a, inputs_a, return_outputs=False):
        # For causal LM training the labels are the input_ids themselves.
        outputs = model_a(**inputs_a, labels=inputs_a["input_ids"])
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


tokenizer = GPTNeoXTokenizerFast.from_pretrained("stabilityai/stablelm-base-alpha-3b")


def process_data(examples):
    texts = examples["text"]
    # Remove empty lines
    texts = [text for text in texts if len(text) > 0 and not text.isspace()]
    # Remove lines that are too long
    texts = [text for text in texts if len(text) < 512]
    # Remove lines that are too short
    texts = [text for text in texts if len(text) > 16]
    # Add a newline character
    texts = [text + ' ' + '\n' for text in texts]
    examples["text"] = texts
    return examples


# Process the dataset column [text]: use the tokenizer to get input_ids and attention_mask.
def process_data_add_mask(examples):
    text = examples['text']
    tokenizer.pad_token = tokenizer.eos_token
    # Tokenize text
    encoded_dict = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=MAX_LEN
    )
    # Add input_ids and attention_mask to the example
    examples['input_ids'] = encoded_dict['input_ids']
    examples['attention_mask'] = encoded_dict['attention_mask']
    return examples


imdb_dataset = load_dataset('imdb')
imdb_dataset_train = imdb_dataset['train'].shuffle()
imdb_dataset_train = imdb_dataset_train.map(process_data, batched=True, remove_columns=['label'])
imdb_dataset_val = imdb_dataset['test'].shuffle()
imdb_dataset_val = imdb_dataset_val.map(process_data, batched=True, remove_columns=['label'])

train_dataset = imdb_dataset_train.map(process_data_add_mask, remove_columns=["text"], batched=True)
val_dataset = imdb_dataset_val.map(process_data_add_mask, remove_columns=["text"], batched=True)

model = GPTNeoXForCausalLM.from_pretrained("stabilityai/stablelm-base-alpha-3b")

# An absolute path is required for the DeepSpeed config;
# you can use the JSON above to create your own config.
z_optimiser = '/two-tb/train_GPTNX/zeromq_config/stablelm-base-alpha-3b_config.json'
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

training_args_v2 = TrainingArguments(
    output_dir="./trained_model",
    learning_rate=2e-5,
    save_total_limit=2,
    fp16=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=12,
    evaluation_strategy="epoch",
    deepspeed=z_optimiser,
    num_train_epochs=1
)

# Set up the trainer
trainer = CustomTrainer(
    model=model,
    args=training_args_v2,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
trainer.save_model()
```
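One gap worth flagging: the `stage2_config` string in the script only shows a `bf16` section and is not otherwise used; the DeepSpeed settings actually come from the JSON file at the `z_optimiser` path. Below is a hedged sketch of what such a file could contain, assuming ZeRO stage 2 with CPU optimizer offload as discussed in this thread; the bucket sizes are illustrative and this is not the author's actual config.

```python
# Illustrative only: write a ZeRO stage 2 DeepSpeed config with CPU
# optimizer offload, then point TrainingArguments(deepspeed=...) at it.
# "auto" values are filled in by the Hugging Face Trainer integration.
import json
import os

ds_config = {
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "allgather_partitions": True,
        "allgather_bucket_size": int(2e8),   # illustrative value
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": int(2e8),      # illustrative value
        "contiguous_gradients": True
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

# Path taken from the script above; create the directory if needed.
config_path = "/two-tb/train_GPTNX/zeromq_config/stablelm-base-alpha-3b_config.json"
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w") as f:
    json.dump(ds_config, f, indent=2)
```

Offloading the optimizer to the CPU is what exercises DeepSpeed's CPU Adam kernel, which is why the setup notes build DeepSpeed with DS_BUILD_CPU_ADAM=1.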