Source code for yamle.data.text

import argparse
from yamle.data.datamodule import BaseDataModule
from typing import Any, Union, Tuple

from pytorch_lightning import LightningModule
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import requests
import os
import torch
from torch.utils.data import random_split
from yamle.defaults import (
    TEXT_CLASSIFICATION_KEY,
    MEAN_PREDICTION_KEY,
    TRAIN_KEY,
    VALIDATION_KEY,
    TEST_KEY,
    INPUT_KEY,
    TARGET_KEY,
)


class TorchtextClassificationDataModule(BaseDataModule):
    """Data module for next-token prediction on text corpora.

    The raw text is tokenized with torchtext's ``basic_english`` tokenizer,
    mapped to vocabulary indices, concatenated into one long token stream and
    cut into non-overlapping windows of ``inputs_dim[0]`` tokens. The target of
    every window is the same window shifted one token to the right.

    Args:
        dataset (str): Name of the dataset. Currently supported are
            `wiki_text_2`, `wiki_text_103`, `imdb` and `shakespeare`.
        validation_portion (float): Portion of the training data to use for validation.
        seed (int): Seed for the random number generator.
        data_dir (str): Path to the data directory.

    All arguments except ``dataset`` are forwarded unchanged to
    ``BaseDataModule``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    mean = None
    std = None
    task = TEXT_CLASSIFICATION_KEY

    inputs_dim = None  # This will be the sequence length
    inputs_dtype = torch.long
    outputs_dim = None  # This will be the size of the vocabulary
    outputs_dtype = torch.long

    # Names accepted by the ``dataset`` constructor argument.
    _SUPPORTED_DATASETS = ("wiki_text_2", "wiki_text_103", "imdb", "shakespeare")
    # Location of the plain-text Shakespeare corpus.
    _SHAKESPEARE_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

    def __init__(self, dataset: str, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        if dataset not in self._SUPPORTED_DATASETS:
            raise ValueError(
                f"Dataset '{dataset}' not supported. "
                f"Supported datasets: {', '.join(self._SUPPORTED_DATASETS)}."
            )
        self._dataset = dataset
        # Concrete subclasses set `inputs_dim` to a 1-tuple holding the window length.
        self._sequence_length = self.inputs_dim[0]
        # Built lazily by the first `_process_data` call and reused afterwards.
        self._vocab: Union[torchtext.vocab.Vocab, None] = None

    @staticmethod
    def _iter_texts(data: Any):
        """Yield the raw text of every item in `data`.

        Plain strings are yielded as they are. For tuple/list items the second
        element is used, which matches the ``(label, text)`` pairs documented
        for torchtext's IMDB dataset.
        """
        for item in data:
            yield item[1] if isinstance(item, (tuple, list)) else item

    def _process_data(self, data: torch.utils.data.Dataset) -> torch.utils.data.Dataset:
        """Tokenize `data`, build the vocabulary and split it into input/target windows.

        Args:
            data: Iterable of text items (strings or ``(label, text)`` pairs).
                It is traversed twice when the vocabulary still has to be
                built, so it must be re-iterable.

        Returns:
            A ``TensorDataset`` of ``(inputs, targets)`` where both tensors have
            shape ``(num_windows, sequence_length)`` and ``targets`` is
            ``inputs`` shifted by one token.

        Raises:
            ValueError: If `data` holds no tokens or fewer than
                ``sequence_length + 1`` of them, i.e. not a single full window.
        """
        tokenizer = get_tokenizer("basic_english")

        # The vocabulary is built from the first data that is processed and
        # then reused, so later splits are encoded with the same indices.
        if self._vocab is None:
            self._vocab = build_vocab_from_iterator(
                map(tokenizer, self._iter_texts(data)), specials=["<unk>"]
            )
            # Tokens unseen while building the vocabulary map to `<unk>`.
            self._vocab.set_default_index(self._vocab["<unk>"])

        encoded = [
            torch.tensor(self._vocab(tokenizer(text)), dtype=torch.long)
            for text in self._iter_texts(data)
        ]
        # Blank lines produce empty tensors, which carry no information.
        encoded = [tokens for tokens in encoded if tokens.numel() > 0]
        if not encoded:
            raise ValueError("The dataset does not contain any tokens.")
        stream = torch.cat(encoded)

        # One extra token is required after the last window for its target;
        # a trailing incomplete window is dropped.
        length = self._sequence_length
        num_windows = (stream.size(0) - 1) // length
        if num_windows == 0:
            raise ValueError(
                "The dataset is too short to form a single sequence of "
                f"length {length}."
            )
        inputs = stream[: num_windows * length].view(num_windows, length)
        targets = stream[1 : num_windows * length + 1].view(num_windows, length)
        return torch.utils.data.TensorDataset(inputs, targets)

    def _load_shakespeare_text(self) -> str:
        """Return the Shakespeare corpus, downloading it into the data directory if needed."""
        path = os.path.join(self._data_dir, "shakespeare.txt")
        # Download only when there is no usable cached copy.
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            os.makedirs(self._data_dir, exist_ok=True)
            response = requests.get(self._SHAKESPEARE_URL, timeout=60)
            # Never cache an HTTP error page as if it were the corpus.
            response.raise_for_status()
            with open(path, "w", encoding="utf-8") as f:
                f.write(response.text)
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def prepare_data(self) -> None:
        """Download and prepare the data, the data is stored in `self._train_dataset`, `self._validation_dataset` and `self._test_dataset`.

        Only the WikiText datasets provide a validation split here; for `imdb`
        and `shakespeare` this method sets the train and test datasets only.

        Raises:
            ValueError: If the configured dataset is not supported.
        """
        super().prepare_data()
        if self._dataset in ("wiki_text_2", "wiki_text_103"):
            if self._dataset == "wiki_text_2":
                factory = torchtext.datasets.WikiText2
            else:
                factory = torchtext.datasets.WikiText103
            # The train split is processed first so that it defines the vocabulary.
            self._train_dataset = self._process_data(
                factory(root=self._data_dir, split="train")
            )
            self._validation_dataset = self._process_data(
                factory(root=self._data_dir, split="valid")
            )
            self._test_dataset = self._process_data(
                factory(root=self._data_dir, split="test")
            )
        elif self._dataset == "imdb":
            self._train_dataset = self._process_data(
                torchtext.datasets.IMDB(root=self._data_dir, split="train")
            )
            self._test_dataset = self._process_data(
                torchtext.datasets.IMDB(root=self._data_dir, split="test")
            )
        elif self._dataset == "shakespeare":
            # The corpus is a single file: process it as a whole, then split
            # the resulting windows into train and test.
            dataset = self._process_data(self._load_shakespeare_text().strip().split())
            n = len(dataset)
            train_size = int(n * (1 - self._test_portion))
            test_size = n - train_size
            self._train_dataset, self._test_dataset = random_split(
                dataset,
                [train_size, test_size],
                generator=torch.Generator().manual_seed(self._seed),
            )
        else:
            raise ValueError("Dataset not supported.")

    @torch.no_grad()
    def _get_prediction(
        self,
        tester: LightningModule,
        x: torch.Tensor,
        y: Union[torch.Tensor, int],
        phase: str = TRAIN_KEY,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one step of `tester` on `(x, y)` and return the predicted token indices.

        Args:
            tester: Module whose training/validation/test step is invoked.
            x: Input batch.
            y: Target batch; a plain ``int`` is converted to a tensor.
            phase: One of `TRAIN_KEY`, `VALIDATION_KEY`, `TEST_KEY`.

        Returns:
            Tuple ``(y_hat, x, y)`` where ``y_hat`` is the argmax of the mean
            prediction over axis 2 and ``x``/``y`` are taken from the step output.

        Raises:
            ValueError: If `phase` is not one of the three known phases.
        """
        super()._get_prediction(tester, x, y, phase)
        x = x.to(tester.device)
        # `as_tensor` moves an existing tensor to the device and also accepts
        # the plain `int` allowed by the signature.
        y = torch.as_tensor(y, device=tester.device)
        if phase == TRAIN_KEY:
            output = tester.training_step([x, y], batch_idx=0)
        elif phase == VALIDATION_KEY:
            output = tester.validation_step([x, y], batch_idx=0)
        elif phase == TEST_KEY:
            output = tester.test_step([x, y], batch_idx=0)
        else:
            raise ValueError(f"Unknown phase: {phase}")
        y_hat = output[MEAN_PREDICTION_KEY]
        x = output[INPUT_KEY]
        y = output[TARGET_KEY]
        # assumes the mean prediction is (batch, sequence, vocabulary) - TODO confirm
        y_hat = torch.argmax(y_hat, dim=2)
        return y_hat, x, y

    @torch.no_grad()
    def plot(
        self, tester: LightningModule, save_path: str, specific_name: str = ""
    ) -> None:
        """Write input, predicted and target text of one batch to a text file.

        The batch is the first one drawn from the training dataloader and the
        prediction is obtained through the tester's test step. The result is
        appended to ``predictions_{specific_name}.txt`` inside `save_path`.

        Args:
            tester: Module used to compute the predictions.
            save_path: Directory in which the text file is written.
            specific_name: Suffix of the output file name.
        """
        # NOTE(review): this method was originally documented as sampling from
        # the test set, but the batch comes from the train dataloader - confirm
        # which one is intended.
        train_dataloader = self.train_dataloader()
        inputs, targets = next(iter(train_dataloader))
        outputs = self._get_prediction(tester, inputs, targets, TEST_KEY)[0]

        input_ids = inputs.cpu().tolist()
        output_ids = outputs.cpu().tolist()
        target_ids = targets.cpu().tolist()

        # Write the text sequences to a file; it is opened once for the whole
        # batch and in UTF-8 so that non-ASCII tokens can always be written.
        path = os.path.join(save_path, f"predictions_{specific_name}.txt")
        with open(path, "a", encoding="utf-8") as f:
            for in_ids, out_ids, tgt_ids in zip(input_ids, output_ids, target_ids):
                f.write("Input: " + " ".join(self._vocab.lookup_tokens(in_ids)))
                f.write("\n")
                f.write("Output: " + " ".join(self._vocab.lookup_tokens(out_ids)))
                f.write("\n")
                f.write("Target: " + " ".join(self._vocab.lookup_tokens(tgt_ids)))
                f.write("\n")
class TorchtextClassificationModelWikiText2(TorchtextClassificationDataModule):
    """Data module for the WikiText2 dataset.

    Uses sequences of 20 tokens and declares an output dimension of 28782.
    """

    inputs_dim = (20,)
    outputs_dim = 28782
    targets_dim = 20

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The dataset name is passed positionally: given as a keyword in front
        # of `*args`, any forwarded positional argument would also bind to
        # `dataset` and raise a TypeError.
        super().__init__("wiki_text_2", *args, **kwargs)
class TorchtextClassificationModelWikiText103(TorchtextClassificationDataModule):
    """Data module for the WikiText103 dataset.

    Uses sequences of 20 tokens and declares an output dimension of 28782.
    """

    inputs_dim = (20,)
    # NOTE(review): same value as the WikiText2 module - confirm that it
    # matches the vocabulary actually built for WikiText103.
    outputs_dim = 28782
    targets_dim = 20

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The dataset name is passed positionally: given as a keyword in front
        # of `*args`, any forwarded positional argument would also bind to
        # `dataset` and raise a TypeError.
        super().__init__("wiki_text_103", *args, **kwargs)
class TorchtextClassificationModelIMDB(TorchtextClassificationDataModule):
    """Data module for the IMDB dataset.

    Uses sequences of 20 tokens and declares an output dimension of 28782.
    """

    inputs_dim = (20,)
    # NOTE(review): same value as the WikiText2 module - confirm that it
    # matches the vocabulary actually built for IMDB.
    outputs_dim = 28782
    targets_dim = 20

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The dataset name is passed positionally: given as a keyword in front
        # of `*args`, any forwarded positional argument would also bind to
        # `dataset` and raise a TypeError.
        super().__init__("imdb", *args, **kwargs)
class Shakespeare(TorchtextClassificationDataModule):
    """Data module for the Shakespeare dataset.

    Uses sequences of 20 tokens and declares an output dimension of 28782.
    """

    inputs_dim = (20,)
    # NOTE(review): same value as the WikiText2 module - confirm that it
    # matches the vocabulary actually built for the Shakespeare corpus.
    outputs_dim = 28782
    targets_dim = 20

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The dataset name is passed positionally: given as a keyword in front
        # of `*args`, any forwarded positional argument would also bind to
        # `dataset` and raise a TypeError.
        super().__init__("shakespeare", *args, **kwargs)