I made a fully custom GPT in JAX (with Keras 3), using TensorFlow for the data pipeline.
I've trained the model on the Shakespeare dataset and got good results (so there's no problem with the model itself). Now I want to train a 15M-parameter GPT on the TinyStories dataset, which is much larger.
Here is the code for loading the data:
```python
# Assumed imports/globals (not shown in this snippet): `tf` (tensorflow),
# `r` (the `random` module) and `spm` (the Llama SentencePiece processor).
def get_dataset_lists(ds_path: str):
    dataset = open(ds_path, "r", encoding="utf-8").read()  # [...]
    dataset = dataset.split("<|endoftext|>")
    r.shuffle(dataset)
    dataset: list = spm.Encode(  # llama's sentencepiece encoder
        tf.strings.strip(dataset).numpy().tolist(),
        add_bos=True,
        add_eos=False,
    )  # [[SOS story], ..., [SOS story]]
    print("\tNumber of stories:", len(dataset))
    return dataset


def tf_dataload(
    dataset: list,
    batch_size: int,
    maxlen: int,
    shift: int,
    reshuffle_each_iteration: bool = True,
):
    import functools
    import operator

    # flatten into one long token stream: [SOS story ... SOS story]
    dataset = functools.reduce(operator.iconcat, dataset, [])
    num_tokens = len(dataset)
    print("\tNumber of tokens in the dataset is", num_tokens)
    unique_tok = set(dataset)
    print("\tNumber of unique tokens in the dataset is", len(unique_tok))

    dataset = tf.data.Dataset.from_tensor_slices(dataset)
    dataset = dataset.window(maxlen + 1, shift=shift, drop_remainder=True)
    # [[...], [...], [...], ...] shape(m, maxlen+1)
    dataset = dataset.flat_map(lambda window: window.batch(maxlen + 1))
    dataset = dataset.shuffle(10_000 * batch_size, reshuffle_each_iteration=reshuffle_each_iteration)
    # [ [ [...], [...], [...], ...], ...] shape(m//B, B, maxlen+1)
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(batch_size * 100)
    dataset = dataset.map(
        lambda window: (window[:, :-1], window[:, 1:]),
        num_parallel_calls=tf.data.AUTOTUNE,
    ).prefetch(tf.data.AUTOTUNE)
    return dataset  # (shape(m//B, B, maxlen), shape(m//B, B, maxlen))


def load_data(
    train_ds_path: str,
    val_ds_path: str,
    batch_size: int,
    maxlen: int,
    shift: int,
):
    print("Training Dataset:")
    train_ds = tf_dataload(get_dataset_lists(train_ds_path), batch_size, maxlen, shift,
                           reshuffle_each_iteration=True)
    print("Validation Dataset:")
    val_ds = tf_dataload(get_dataset_lists(val_ds_path), batch_size, maxlen, shift,
                         reshuffle_each_iteration=True)
    print(f"\n{train_ds}\n{val_ds}")
    datasets = {"train": train_ds.repeat(), "val": val_ds}
    return datasets
```
- I have some questions regarding the value of `shift`:
First I set it equal to 1, but training was very slow: even after 100,000 steps the loss hadn't converged, although it was still decreasing slowly. (I don't think the learning rate is the problem: I plotted loss vs. learning rate, picked the largest learning rate that still trained stably, and used cosine decay with warmup, sketched after the config below.)
Then I looked into Karpathy's llama2 repo, where the shift was equal to `maxlen`, so I set it to `maxlen` and trained for 100,000 steps. The model still learns very slowly and never reaches a loss anywhere close to what Karpathy got (I don't know what the problem is, as I've followed his repo closely).
What is `shift` generally set to when pre-training an LLM on language modelling? Shouldn't it be 1, because the transformer is not positionally invariant, so wouldn't a `shift` other than 1 hurt model performance? But then the number of samples becomes very large... (The windowing sketch right after these questions shows what the two settings produce.)
- And for how many steps should an LLM be trained, given the number of tokens in the dataset? (The back-of-the-envelope arithmetic I'm using is sketched after the config below.)
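Here is a tiny, self-contained sketch of what the `tf.data` windowing produces for the two `shift` settings (toy numbers, not the real dataset):

```python
import tensorflow as tf

# Toy illustration: with shift=1 consecutive windows overlap by maxlen tokens,
# so there are roughly num_tokens samples; with shift=maxlen the windows are
# disjoint, so there are roughly num_tokens / maxlen samples.
tokens = list(range(12))
maxlen = 4

for shift in (1, maxlen):
    ds = tf.data.Dataset.from_tensor_slices(tokens)
    ds = ds.window(maxlen + 1, shift=shift, drop_remainder=True)
    ds = ds.flat_map(lambda w: w.batch(maxlen + 1))
    windows = [w.numpy().tolist() for w in ds]
    print(f"shift={shift}: {len(windows)} windows, first two: {windows[:2]}")

# shift=1: 8 windows, first two: [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]
# shift=4: 2 windows, first two: [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8]]
```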
You may find the below helpful...
```python
@dataclass
class GPTArgs:
    """GPT Configuration"""
    d_model: int = 288
    num_layers: int = 6
    num_heads: int = 6
    max_context_length: int = 256
    vocab_size: int = VOCAB_SIZE  # 32K
    output_units: int = None  # equal to vocab_size if None in model init

    assert d_model % 2 == 0
    assert d_model % num_heads == 0

    dropout_rate: float = 0.1


@dataclass
class TArgs:
    # lr scheduler
    init_lr: float = 1e-7
    max_lr: float = 6.5e-4
    min_lr: float = 0.1 * max_lr  # the factor is usually 0.1 or 0.0
    num_steps: int = 100_000
    warmup_steps: int = 1000      # 1000 instead of 2000, to make training more stable
    decay_steps: int = num_steps

    # optimizer
    beta1: float = 0.9
    beta2: float = 0.95
    weight_decay: float = 1e-1
    clipvalue: float = 1e0
    num_grad_accumulation_steps: int = 4
    # num_tok_per_update = batch_size * maxlen * grad_accumulation = 128 * 256 * 4 = 131_072

    # training
    checkpoint: str = 'weights/GPTstories/Epoch{epoch}.weights.h5'
    train_ds_path: str = "TinyStoriesDataset/TinyStories-train.txt"
    val_ds_path: str = "TinyStoriesDataset/TinyStories-valid.txt"
    steps_per_epoch = eval_freq = 2000
    eval_steps: int = 200
    batch_size: int = 128
    patience: int = 10  # early stopping with restore best weights
```
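The learning-rate schedule I mention above is just the `TArgs` values fed into a cosine decay with linear warmup; a minimal sketch, assuming Keras 3's built-in `CosineDecay` with its `warmup_target` / `warmup_steps` arguments:

```python
import keras

# Minimal sketch of the schedule described above, using the TArgs values.
# With warmup_target set, alpha should act as the floor as a fraction of the
# peak, so alpha=0.1 gives min_lr = 0.1 * max_lr.
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-7,   # init_lr
    warmup_target=6.5e-4,         # max_lr (picked from the loss-vs-lr sweep)
    warmup_steps=1000,
    decay_steps=100_000,          # decay_steps = num_steps
    alpha=0.1,
)
```

And for the second question, the rough arithmetic I'm working with (the token count below is a placeholder, not the real TinyStories count; the real value is printed by the data pipeline):

```python
# Placeholder numbers: num_tokens is whatever the pipeline prints,
# the rest comes from GPTArgs / TArgs above.
num_tokens = 500_000_000                               # placeholder
batch_size, maxlen, grad_accum = 128, 256, 4

tokens_per_update = batch_size * maxlen * grad_accum   # 131_072
steps_to_see_data_once = num_tokens / tokens_per_update
print(f"~{steps_to_see_data_once:,.0f} optimizer steps to see every token once")
# With shift=1 every token reappears in ~maxlen overlapping windows, so one
# pass over all *windows* takes roughly maxlen times as many steps.
```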
Update 1:
I thought the model wasn't seeing the training samples uniformly, so I modified the data pipeline and also increased the number of steps to 200,000. There were no significant improvements: training was still very slow towards the end, with the loss decreasing by about 0.01 per epoch (of 2,000 steps). I got a loss of 1.67 on the validation set.
```python
import functools, glob, operator, os

import numpy as np
import tensorflow as tf
# `r` (random) and `spm` (the SentencePiece processor) are the same globals as before.


def pretokenize_and_save_dataset(dataset_path: str, num_shards: int, shard_dir: str):
    dataset = open(dataset_path, "r", encoding="utf-8").read()  # [...]
    dataset = dataset.split("<|endoftext|>")
    r.shuffle(dataset)
    dataset: list = spm.Encode(
        tf.strings.strip(dataset).numpy().tolist(),
        add_bos=True,
        add_eos=False,
    )  # [[SOS story], ..., [SOS story]]
    print("Dataset:")
    print("\tNumber of stories:", len(dataset))

    # flatten
    dataset = functools.reduce(operator.iconcat, dataset, [])
    num_tokens = len(dataset)
    print("\tNumber of tokens in the dataset:", num_tokens)
    print("\tNumber of unique tokens in the dataset:", len(set(dataset)))
    dataset = np.asarray(dataset, dtype=np.uint16)  # [SOS story ... SOS story]
    print("\tAvg length of story:", num_tokens / ((dataset == 1).sum()))

    # shard and save dataset
    sharded_datasets_list = np.array_split(dataset, num_shards)  # [[SOS story...], [...], [...], ...]
    filenames = [os.path.join(shard_dir, f"shard{i+1}.npy") for i in range(num_shards)]
    for filename, sharded_ds in zip(filenames, sharded_datasets_list):
        with open(filename, "wb") as f:
            np.save(f, sharded_ds)
    return filenames


def load_data_as_tfds(
    dataset: np.ndarray,
    maxlen: int,
    shift: int,
):
    # [SOS story ... SOS story]
    dataset = tf.data.Dataset.from_tensor_slices(dataset.tolist())
    dataset = dataset.window(maxlen + 1, shift=shift, drop_remainder=True)
    # [[...], [...], [...], ...] shape(m, maxlen+1)
    dataset = dataset.flat_map(lambda window: window.batch(maxlen + 1))
    dataset = dataset.shuffle(10_000 * 128)
    return dataset


def batch_tfds(
    dataset: tf.data.Dataset,
    batch_size: int,
):
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(batch_size * 1000)
    dataset = dataset.map(lambda window: (window[:, :-1], window[:, 1:]), num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.repeat().prefetch(tf.data.AUTOTUNE)
    return dataset


def load_data(
    dataset_path: str,
    batch_size: int,
    maxlen: int,
    shift: int,
    num_shards: int,
    shard_dir: str,
):
    if os.path.exists(shard_dir) and os.listdir(shard_dir):
        filenames = glob.glob(os.path.join(shard_dir, "*.npy"))
    else:
        os.makedirs(shard_dir)
        filenames = pretokenize_and_save_dataset(dataset_path, num_shards=num_shards, shard_dir=shard_dir)
    r.shuffle(filenames)

    to_tfds = lambda dataset: load_data_as_tfds(dataset, maxlen=maxlen, shift=shift)
    num_train_shards = round(0.9651 * num_shards)
    num_val_shards = num_shards - num_train_shards

    print("Training Dataset:")
    print(f"\tNumber of files taken for training: {num_train_shards}/{num_shards}")
    train_datasets_lists = [to_tfds(np.load(filename)) for filename in filenames[:num_train_shards]]
    train_ds = tf.data.Dataset.sample_from_datasets(train_datasets_lists, weights=[1/num_train_shards]*num_train_shards)
    # [ [ [...], [...], [...], ...], ...] shape(m//B, B, maxlen+1)
    train_ds = batch_tfds(train_ds, batch_size=batch_size)

    print("Validation Dataset:")
    print(f"\tNumber of files taken for validation: {num_val_shards}/{num_shards}")
    val_datasets_lists = [to_tfds(np.load(filename)) for filename in filenames[num_train_shards:]]
    val_ds = tf.data.Dataset.sample_from_datasets(val_datasets_lists, weights=[1/num_val_shards]*num_val_shards)
    # [ [ [...], [...], [...], ...], ...] shape(m//B, B, maxlen+1)
    val_ds = batch_tfds(val_ds, batch_size=batch_size)

    print(f"\n{train_ds}\n{val_ds}")
    datasets = {"train": train_ds, "val": val_ds}
    return datasets
```
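This is roughly how the sharded pipeline gets called; the shard count and shard directory below are placeholders, the rest comes from the config above:

```python
# Roughly how the sharded pipeline is used; num_shards and shard_dir are
# placeholders, the other values come from GPTArgs / TArgs above.
datasets = load_data(
    dataset_path="TinyStoriesDataset/TinyStories-train.txt",
    batch_size=128,
    maxlen=256,
    shift=256,
    num_shards=64,                           # placeholder
    shard_dir="TinyStoriesDataset/shards",   # placeholder
)
train_ds, val_ds = datasets["train"], datasets["val"]
```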
Update 2:
- Gradient accumulation makes training slow, so I changed the number of accumulation steps to 1 and trained for 200,000 steps; I got a validation loss of 1.60.
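For reference, the accumulation referred to here is optimizer-level; a minimal sketch of the optimizer setup, assuming Keras 3's `gradient_accumulation_steps` argument (the other values come from `TArgs`):

```python
import keras

# Minimal sketch assuming Keras 3 optimizers' gradient_accumulation_steps
# argument; None means a plain update every step (what Update 2 switches to),
# while 4 reproduces the 131_072-tokens-per-update setting from TArgs.
optimizer = keras.optimizers.AdamW(
    learning_rate=6.5e-4,              # or the cosine-with-warmup schedule above
    weight_decay=1e-1,
    beta_1=0.9,
    beta_2=0.95,
    clipvalue=1.0,
    gradient_accumulation_steps=None,  # was 4
)
```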