From 71b7cf6dd78bc61b927afa4333dc96dc4dc1647d Mon Sep 17 00:00:00 2001 From: "Lucas Lopes Oliveira (lopesoll)" Date: Fri, 10 Mar 2023 16:45:55 +0000 Subject: [PATCH] Add files via upload --- Seq2Seq translation model.ipynb | 844 ++++++++++++++++++++++++++++++++ 1 file changed, 844 insertions(+) create mode 100644 Seq2Seq translation model.ipynb diff --git a/Seq2Seq translation model.ipynb b/Seq2Seq translation model.ipynb new file mode 100644 index 0000000..b236523 --- /dev/null +++ b/Seq2Seq translation model.ipynb @@ -0,0 +1,844 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "63a0d16f", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from tensorflow.keras.layers import Input, LSTM, Dense, RepeatVector, Embedding\n", + "from tensorflow.keras import Model\n", + "from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping\n", + "from tensorflow.keras.models import load_model\n", + "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", + "from tensorflow.keras.preprocessing.text import Tokenizer\n", + "from keras.utils.vis_utils import plot_model\n", + "import matplotlib.pyplot as plt\n", + "from nltk.translate.bleu_score import sentence_bleu, corpus_bleu\n", + "from nltk.tokenize import word_tokenize" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "96100a91", + "metadata": {}, + "outputs": [], + "source": [ + "df=pd.read_csv(\"eng-fr.csv\")\n", + "#df=df.sample(40000, random_state=42)\n", + "df=df[:40000]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3bf99a74", + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "# remove unwanted characters\n", + "def remove_chars(text):\n", + "    \"\"\"Return text with every non-word, non-space character replaced by a space and whitespace runs collapsed.\"\"\"\n", + "    # anything that is neither a word character nor whitespace becomes a space\n", + "    text = re.sub(r'[^\\w\\s]', ' ', text)\n", + "    # collapse whitespace runs into single spaces and trim both ends (raw-string pattern)\n", + "    return re.sub(r'\\s+', ' ', text).strip()\n", + "\n", + "text_columns = ['English words/sentences', 'French words/sentences']\n", + "# keep only rows where both sentences are present; .copy() so the assignments below do not write to a slice\n", + "df = df.dropna(subset=text_columns).copy()\n", + "for column in text_columns:\n", +
"    df[column] = df[column].str.lower().apply(remove_chars)\n", + "eng, fr = df[\"English words/sentences\"], df[\"French words/sentences\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f316ab31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(7, 15)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#calculate the maximum length of a sentence (nr of words)\n", + "max_len_eng_sentence = [len(s.split()) for s in eng]\n", + "max_len_fr_sentence = [len(s.split()) for s in fr]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1e817b79", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9332\n", + "15\n", + "5156\n", + "7\n" + ] + } + ], + "source": [ + "#create french tokenizer\n", + "fr_tokenizer = Tokenizer()\n", + "fr_tokenizer.fit_on_texts(fr)\n", + "#get vocabulary size and maximum length of a sentence \n", + "fr_vocab_size = len(fr_tokenizer.word_index) + 1\n", + "fr_length = max(max_len_fr_sentence)\n", + "print(fr_vocab_size)\n", + "print(fr_length)\n", + "\n", + "#create english tokenizer\n", + "eng_tokenizer = Tokenizer()\n", + "eng_tokenizer.fit_on_texts(eng)\n", + "#get vocabulary size and maximum length of a sentence \n", + "eng_vocab_size = len(eng_tokenizer.word_index) + 1\n", + "eng_length = max(max_len_eng_sentence)\n", + "print(eng_vocab_size)\n", + "print(eng_length)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f5d31d1c", + "metadata": {}, + "outputs": [], + "source": [ + "#split data into train and test \n", + "from sklearn.model_selection import train_test_split\n", + "train_eng, test_eng, train_fr, test_fr = 
train_test_split(eng, fr, test_size=0.1, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "732478ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((4000, 15), (4000, 7))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#fit the tokenizer and add padding to the training and testing data\n", + "train_fr_encoded = fr_tokenizer.texts_to_sequences(train_fr)\n", + "train_fr_encoded = pad_sequences(train_fr_encoded, maxlen=fr_length, padding='post')\n", + "\n", + "train_eng_encoded = eng_tokenizer.texts_to_sequences(train_eng)\n", + "train_eng_encoded = pad_sequences(train_eng_encoded, maxlen=eng_length, padding='post')\n", + "\n", + "test_fr_encoded = fr_tokenizer.texts_to_sequences(test_fr)\n", + "test_fr_encoded = pad_sequences(test_fr_encoded, maxlen=fr_length, padding='post')\n", + "\n", + "test_eng_encoded = eng_tokenizer.texts_to_sequences(test_eng)\n", + "test_eng_encoded = pad_sequences(test_eng_encoded, maxlen=eng_length, padding='post')\n", + "\n", + "test_fr_encoded.shape, test_eng_encoded.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "88561285", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " Encoder_Input (InputLayer) [(None, 7)] 0 \n", + " \n", + " Encoder_embedding (Embeddin (None, 7, 256) 1319936 \n", + " g) \n", + " \n", + " Encoder_LSTM (LSTM) (None, 256) 525312 \n", + " \n", + " Reperat_Vector (RepeatVecto (None, 15, 256) 0 \n", + " r) \n", + " \n", + " Decoder_LSTM (LSTM) (None, 15, 256) 525312 \n", + " \n", + " Decoder_Softmax (Dense) (None, 15, 9332) 2398324 \n", + " \n", + 
"=================================================================\n", + "Total params: 4,768,884\n", + "Trainable params: 4,768,884\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "#create the model \n", + "# create the input layer for the encoder\n", + "encoder_inputs= Input(shape=(eng_length,), name='Encoder_Input')\n", + "# create an embedding layer for the encoder input\n", + "encoder_embedding=Embedding(eng_vocab_size, 256, name='Encoder_embedding')(encoder_inputs)\n", + "# create a LSTM layer with 256 units for the encoder\n", + "encoder_lstm=LSTM(256, name='Encoder_LSTM')(encoder_embedding)\n", + "# Repeat the encoder output fr_length times\n", + "repeat_vector=RepeatVector(fr_length, name='Repeat_Vector')(encoder_lstm)\n", + "\n", + "# Define a LSTM layer with 256 units for the decoder\n", + "decoder_lstm=LSTM(256, return_sequences=True, name='Decoder_LSTM')(repeat_vector)\n", + "# Define a Dense layer for the decoder output\n", + "decoder_softmax=Dense(fr_vocab_size, activation=\"softmax\", name='Decoder_Softmax')(decoder_lstm)\n", + "# Define the model with the encoder input as input and the decoder output as output\n", + "model= Model(inputs=encoder_inputs, outputs=decoder_softmax)\n", + "\n", + "# Compile the model \n", + "model.compile(optimizer=\"rmsprop\", loss=\"sparse_categorical_crossentropy\", metrics=['accuracy'])\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9c68204c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#plot the model\n", + "from keras.utils import plot_model\n", + "plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "890fb658", + "metadata": {}, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/100\n", + "64/64 [==============================] - 67s 986ms/step - loss: 3.3326 - accuracy: 0.6865 - val_loss: 2.1809 - val_accuracy: 0.7102\n", + "Epoch 2/100\n", + "64/64 [==============================] - 61s 954ms/step - loss: 2.1077 - accuracy: 0.7084 - val_loss: 2.0131 - val_accuracy: 0.7158\n", + "Epoch 3/100\n", + "64/64 [==============================] - 61s 958ms/step - loss: 1.9839 - accuracy: 0.7148 - val_loss: 1.9478 - val_accuracy: 0.7156\n", + "Epoch 4/100\n", + "64/64 [==============================] - 61s 956ms/step - loss: 1.9316 - accuracy: 0.7158 - val_loss: 1.9157 - val_accuracy: 0.7169\n", + "Epoch 5/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.9112 - accuracy: 0.7164 - val_loss: 1.9096 - val_accuracy: 0.7175\n", + "Epoch 6/100\n", + "64/64 [==============================] - 60s 931ms/step - loss: 1.8933 - accuracy: 0.7182 - val_loss: 1.9011 - val_accuracy: 0.7217\n", + "Epoch 7/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.8798 - accuracy: 0.7205 - val_loss: 1.8798 - val_accuracy: 0.7225\n", + "Epoch 8/100\n", + "64/64 [==============================] - 60s 935ms/step - loss: 1.8654 - accuracy: 0.7219 - val_loss: 1.9422 - val_accuracy: 0.7180\n", + "Epoch 9/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.8558 - accuracy: 0.7227 - val_loss: 1.8634 - val_accuracy: 0.7250\n", + "Epoch 10/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.8423 - accuracy: 0.7240 - val_loss: 1.8592 - val_accuracy: 0.7251\n", + "Epoch 11/100\n", + "64/64 [==============================] - 60s 930ms/step - loss: 1.8314 - accuracy: 0.7256 - val_loss: 1.8370 - val_accuracy: 0.7246\n", + "Epoch 12/100\n", + "64/64 [==============================] - 60s 931ms/step - loss: 1.8220 - accuracy: 0.7270 - val_loss: 1.8214 - val_accuracy: 0.7288\n", + "Epoch 13/100\n", + 
"64/64 [==============================] - 60s 933ms/step - loss: 1.8135 - accuracy: 0.7277 - val_loss: 1.8351 - val_accuracy: 0.7257\n", + "Epoch 14/100\n", + "64/64 [==============================] - 61s 954ms/step - loss: 1.8059 - accuracy: 0.7286 - val_loss: 1.8158 - val_accuracy: 0.7276\n", + "Epoch 15/100\n", + "64/64 [==============================] - 63s 984ms/step - loss: 1.7999 - accuracy: 0.7291 - val_loss: 1.8177 - val_accuracy: 0.7273\n", + "Epoch 16/100\n", + "64/64 [==============================] - 60s 935ms/step - loss: 1.7907 - accuracy: 0.7300 - val_loss: 1.8428 - val_accuracy: 0.7234\n", + "Epoch 17/100\n", + "64/64 [==============================] - 60s 938ms/step - loss: 1.7839 - accuracy: 0.7305 - val_loss: 1.8015 - val_accuracy: 0.7319\n", + "Epoch 18/100\n", + "64/64 [==============================] - 60s 935ms/step - loss: 1.7763 - accuracy: 0.7314 - val_loss: 1.7842 - val_accuracy: 0.7334\n", + "Epoch 19/100\n", + "64/64 [==============================] - 60s 937ms/step - loss: 1.7677 - accuracy: 0.7324 - val_loss: 1.7858 - val_accuracy: 0.7340\n", + "Epoch 20/100\n", + "64/64 [==============================] - 60s 934ms/step - loss: 1.7594 - accuracy: 0.7338 - val_loss: 1.8034 - val_accuracy: 0.7297\n", + "Epoch 21/100\n", + "64/64 [==============================] - 60s 939ms/step - loss: 1.7508 - accuracy: 0.7346 - val_loss: 1.7792 - val_accuracy: 0.7341\n", + "Epoch 22/100\n", + "64/64 [==============================] - 59s 929ms/step - loss: 1.7412 - accuracy: 0.7359 - val_loss: 1.7605 - val_accuracy: 0.7372\n", + "Epoch 23/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.7292 - accuracy: 0.7377 - val_loss: 1.7479 - val_accuracy: 0.7402\n", + "Epoch 24/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.7151 - accuracy: 0.7398 - val_loss: 1.7390 - val_accuracy: 0.7402\n", + "Epoch 25/100\n", + "64/64 [==============================] - 59s 930ms/step - loss: 1.7026 - accuracy: 0.7408 
- val_loss: 1.7341 - val_accuracy: 0.7406\n", + "Epoch 26/100\n", + "64/64 [==============================] - 60s 931ms/step - loss: 1.6902 - accuracy: 0.7420 - val_loss: 1.7177 - val_accuracy: 0.7432\n", + "Epoch 27/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.6752 - accuracy: 0.7444 - val_loss: 1.7078 - val_accuracy: 0.7421\n", + "Epoch 28/100\n", + "64/64 [==============================] - 60s 930ms/step - loss: 1.6597 - accuracy: 0.7463 - val_loss: 1.6868 - val_accuracy: 0.7467\n", + "Epoch 29/100\n", + "64/64 [==============================] - 59s 928ms/step - loss: 1.6456 - accuracy: 0.7485 - val_loss: 1.6678 - val_accuracy: 0.7491\n", + "Epoch 30/100\n", + "64/64 [==============================] - 59s 929ms/step - loss: 1.6319 - accuracy: 0.7507 - val_loss: 1.6539 - val_accuracy: 0.7491\n", + "Epoch 31/100\n", + "64/64 [==============================] - 60s 934ms/step - loss: 1.6187 - accuracy: 0.7522 - val_loss: 1.6439 - val_accuracy: 0.7541\n", + "Epoch 32/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.6052 - accuracy: 0.7541 - val_loss: 1.6314 - val_accuracy: 0.7543\n", + "Epoch 33/100\n", + "64/64 [==============================] - 60s 934ms/step - loss: 1.5913 - accuracy: 0.7560 - val_loss: 1.6236 - val_accuracy: 0.7559\n", + "Epoch 34/100\n", + "64/64 [==============================] - 60s 942ms/step - loss: 1.5807 - accuracy: 0.7576 - val_loss: 1.6197 - val_accuracy: 0.7563\n", + "Epoch 35/100\n", + "64/64 [==============================] - 58s 912ms/step - loss: 1.5699 - accuracy: 0.7587 - val_loss: 1.6049 - val_accuracy: 0.7584\n", + "Epoch 36/100\n", + "64/64 [==============================] - 59s 926ms/step - loss: 1.5603 - accuracy: 0.7603 - val_loss: 1.6051 - val_accuracy: 0.7593\n", + "Epoch 37/100\n", + "64/64 [==============================] - 58s 903ms/step - loss: 1.5503 - accuracy: 0.7618 - val_loss: 1.6053 - val_accuracy: 0.7616\n", + "Epoch 38/100\n", + "64/64 
[==============================] - 60s 934ms/step - loss: 1.5406 - accuracy: 0.7633 - val_loss: 1.5776 - val_accuracy: 0.7642\n", + "Epoch 39/100\n", + "64/64 [==============================] - 58s 902ms/step - loss: 1.5309 - accuracy: 0.7651 - val_loss: 1.5778 - val_accuracy: 0.7627\n", + "Epoch 40/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.5220 - accuracy: 0.7659 - val_loss: 1.5600 - val_accuracy: 0.7642\n", + "Epoch 41/100\n", + "64/64 [==============================] - 59s 915ms/step - loss: 1.5124 - accuracy: 0.7671 - val_loss: 1.5573 - val_accuracy: 0.7639\n", + "Epoch 42/100\n", + "64/64 [==============================] - 59s 929ms/step - loss: 1.5032 - accuracy: 0.7689 - val_loss: 1.5625 - val_accuracy: 0.7671\n", + "Epoch 43/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.4935 - accuracy: 0.7703 - val_loss: 1.5408 - val_accuracy: 0.7679\n", + "Epoch 44/100\n", + "64/64 [==============================] - 59s 923ms/step - loss: 1.4842 - accuracy: 0.7713 - val_loss: 1.5266 - val_accuracy: 0.7701\n", + "Epoch 45/100\n", + "64/64 [==============================] - 60s 937ms/step - loss: 1.4745 - accuracy: 0.7723 - val_loss: 1.5265 - val_accuracy: 0.7702\n", + "Epoch 46/100\n", + "64/64 [==============================] - 58s 912ms/step - loss: 1.4667 - accuracy: 0.7733 - val_loss: 1.5238 - val_accuracy: 0.7688\n", + "Epoch 47/100\n", + "64/64 [==============================] - 59s 924ms/step - loss: 1.4580 - accuracy: 0.7748 - val_loss: 1.5150 - val_accuracy: 0.7706\n", + "Epoch 48/100\n", + "64/64 [==============================] - 58s 910ms/step - loss: 1.4489 - accuracy: 0.7756 - val_loss: 1.5304 - val_accuracy: 0.7670\n", + "Epoch 49/100\n", + "64/64 [==============================] - 60s 940ms/step - loss: 1.4419 - accuracy: 0.7765 - val_loss: 1.5184 - val_accuracy: 0.7722\n", + "Epoch 50/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.4337 - accuracy: 0.7775 - 
val_loss: 1.5022 - val_accuracy: 0.7729\n", + "Epoch 51/100\n", + "64/64 [==============================] - 59s 925ms/step - loss: 1.4260 - accuracy: 0.7782 - val_loss: 1.4906 - val_accuracy: 0.7766\n", + "Epoch 52/100\n", + "64/64 [==============================] - 60s 934ms/step - loss: 1.4190 - accuracy: 0.7793 - val_loss: 1.4786 - val_accuracy: 0.7757\n", + "Epoch 53/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.4112 - accuracy: 0.7801 - val_loss: 1.4856 - val_accuracy: 0.7739\n", + "Epoch 54/100\n", + "64/64 [==============================] - 59s 923ms/step - loss: 1.4045 - accuracy: 0.7808 - val_loss: 1.4742 - val_accuracy: 0.7785\n", + "Epoch 55/100\n", + "64/64 [==============================] - 60s 945ms/step - loss: 1.3971 - accuracy: 0.7818 - val_loss: 1.4632 - val_accuracy: 0.7782\n", + "Epoch 56/100\n", + "64/64 [==============================] - 61s 949ms/step - loss: 1.3900 - accuracy: 0.7825 - val_loss: 1.4676 - val_accuracy: 0.7771\n", + "Epoch 57/100\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "64/64 [==============================] - 60s 944ms/step - loss: 1.3835 - accuracy: 0.7832 - val_loss: 1.4466 - val_accuracy: 0.7804\n", + "Epoch 58/100\n", + "64/64 [==============================] - 61s 949ms/step - loss: 1.3759 - accuracy: 0.7841 - val_loss: 1.4625 - val_accuracy: 0.7795\n", + "Epoch 59/100\n", + "64/64 [==============================] - 60s 940ms/step - loss: 1.3696 - accuracy: 0.7847 - val_loss: 1.4471 - val_accuracy: 0.7812\n", + "Epoch 60/100\n", + "64/64 [==============================] - 60s 942ms/step - loss: 1.3634 - accuracy: 0.7855 - val_loss: 1.4335 - val_accuracy: 0.7816\n", + "Epoch 61/100\n", + "64/64 [==============================] - 61s 949ms/step - loss: 1.3582 - accuracy: 0.7860 - val_loss: 1.4254 - val_accuracy: 0.7827\n", + "Epoch 62/100\n", + "64/64 [==============================] - 60s 942ms/step - loss: 1.3511 - accuracy: 0.7867 - val_loss: 1.4250 - 
val_accuracy: 0.7838\n", + "Epoch 63/100\n", + "64/64 [==============================] - 61s 951ms/step - loss: 1.3444 - accuracy: 0.7872 - val_loss: 1.4377 - val_accuracy: 0.7817\n", + "Epoch 64/100\n", + "64/64 [==============================] - 61s 953ms/step - loss: 1.3388 - accuracy: 0.7878 - val_loss: 1.4179 - val_accuracy: 0.7847\n", + "Epoch 65/100\n", + "64/64 [==============================] - 61s 947ms/step - loss: 1.3322 - accuracy: 0.7885 - val_loss: 1.4147 - val_accuracy: 0.7843\n", + "Epoch 66/100\n", + "64/64 [==============================] - 59s 930ms/step - loss: 1.3272 - accuracy: 0.7891 - val_loss: 1.4056 - val_accuracy: 0.7853\n", + "Epoch 67/100\n", + "64/64 [==============================] - 60s 932ms/step - loss: 1.3206 - accuracy: 0.7894 - val_loss: 1.3961 - val_accuracy: 0.7860\n", + "Epoch 68/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.3153 - accuracy: 0.7901 - val_loss: 1.4083 - val_accuracy: 0.7832\n", + "Epoch 69/100\n", + "64/64 [==============================] - 60s 943ms/step - loss: 1.3101 - accuracy: 0.7905 - val_loss: 1.3947 - val_accuracy: 0.7862\n", + "Epoch 70/100\n", + "64/64 [==============================] - 61s 949ms/step - loss: 1.3037 - accuracy: 0.7911 - val_loss: 1.3884 - val_accuracy: 0.7871\n", + "Epoch 71/100\n", + "64/64 [==============================] - 61s 950ms/step - loss: 1.2975 - accuracy: 0.7916 - val_loss: 1.3885 - val_accuracy: 0.7867\n", + "Epoch 72/100\n", + "64/64 [==============================] - 61s 951ms/step - loss: 1.2933 - accuracy: 0.7921 - val_loss: 1.3808 - val_accuracy: 0.7879\n", + "Epoch 73/100\n", + "64/64 [==============================] - 61s 959ms/step - loss: 1.2880 - accuracy: 0.7923 - val_loss: 1.3793 - val_accuracy: 0.7885\n", + "Epoch 74/100\n", + "64/64 [==============================] - 62s 963ms/step - loss: 1.2819 - accuracy: 0.7930 - val_loss: 1.3728 - val_accuracy: 0.7885\n", + "Epoch 75/100\n", + "64/64 [==============================] - 
62s 965ms/step - loss: 1.2747 - accuracy: 0.7937 - val_loss: 1.3722 - val_accuracy: 0.7874\n", + "Epoch 76/100\n", + "64/64 [==============================] - 62s 963ms/step - loss: 1.2700 - accuracy: 0.7942 - val_loss: 1.3701 - val_accuracy: 0.7883\n", + "Epoch 77/100\n", + "64/64 [==============================] - 62s 969ms/step - loss: 1.2646 - accuracy: 0.7946 - val_loss: 1.3617 - val_accuracy: 0.7885\n", + "Epoch 78/100\n", + "64/64 [==============================] - 62s 974ms/step - loss: 1.2595 - accuracy: 0.7951 - val_loss: 1.3726 - val_accuracy: 0.7893\n", + "Epoch 79/100\n", + "64/64 [==============================] - 60s 942ms/step - loss: 1.2546 - accuracy: 0.7957 - val_loss: 1.3537 - val_accuracy: 0.7900\n", + "Epoch 80/100\n", + "64/64 [==============================] - 61s 947ms/step - loss: 1.2486 - accuracy: 0.7960 - val_loss: 1.3432 - val_accuracy: 0.7910\n", + "Epoch 81/100\n", + "64/64 [==============================] - 60s 940ms/step - loss: 1.2426 - accuracy: 0.7966 - val_loss: 1.3421 - val_accuracy: 0.7915\n", + "Epoch 82/100\n", + "64/64 [==============================] - 61s 946ms/step - loss: 1.2385 - accuracy: 0.7969 - val_loss: 1.3483 - val_accuracy: 0.7898\n", + "Epoch 83/100\n", + "64/64 [==============================] - 61s 948ms/step - loss: 1.2327 - accuracy: 0.7977 - val_loss: 1.3351 - val_accuracy: 0.7914\n", + "Epoch 84/100\n", + "64/64 [==============================] - 61s 960ms/step - loss: 1.2276 - accuracy: 0.7980 - val_loss: 1.3583 - val_accuracy: 0.7906\n", + "Epoch 85/100\n", + "64/64 [==============================] - 61s 954ms/step - loss: 1.2221 - accuracy: 0.7984 - val_loss: 1.3335 - val_accuracy: 0.7908\n", + "Epoch 86/100\n", + "64/64 [==============================] - 60s 941ms/step - loss: 1.2168 - accuracy: 0.7989 - val_loss: 1.3270 - val_accuracy: 0.7922\n", + "Epoch 87/100\n", + "64/64 [==============================] - 60s 937ms/step - loss: 1.2116 - accuracy: 0.7995 - val_loss: 1.3230 - val_accuracy: 
0.7936\n", + "Epoch 88/100\n", + "64/64 [==============================] - 61s 958ms/step - loss: 1.2069 - accuracy: 0.7999 - val_loss: 1.3136 - val_accuracy: 0.7938\n", + "Epoch 89/100\n", + "64/64 [==============================] - 60s 944ms/step - loss: 1.2012 - accuracy: 0.8008 - val_loss: 1.3188 - val_accuracy: 0.7939\n", + "Epoch 90/100\n", + "64/64 [==============================] - 60s 934ms/step - loss: 1.1955 - accuracy: 0.8010 - val_loss: 1.3250 - val_accuracy: 0.7948\n", + "Epoch 91/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.1907 - accuracy: 0.8013 - val_loss: 1.3162 - val_accuracy: 0.7935\n", + "Epoch 92/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.1863 - accuracy: 0.8020 - val_loss: 1.3137 - val_accuracy: 0.7937\n", + "Epoch 93/100\n", + "64/64 [==============================] - 60s 937ms/step - loss: 1.1805 - accuracy: 0.8026 - val_loss: 1.3008 - val_accuracy: 0.7956\n", + "Epoch 94/100\n", + "64/64 [==============================] - 60s 940ms/step - loss: 1.1764 - accuracy: 0.8025 - val_loss: 1.3036 - val_accuracy: 0.7945\n", + "Epoch 95/100\n", + "64/64 [==============================] - 60s 940ms/step - loss: 1.1711 - accuracy: 0.8035 - val_loss: 1.2928 - val_accuracy: 0.7961\n", + "Epoch 96/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.1667 - accuracy: 0.8037 - val_loss: 1.2971 - val_accuracy: 0.7951\n", + "Epoch 97/100\n", + "64/64 [==============================] - 60s 943ms/step - loss: 1.1608 - accuracy: 0.8046 - val_loss: 1.2945 - val_accuracy: 0.7963\n", + "Epoch 98/100\n", + "64/64 [==============================] - 60s 933ms/step - loss: 1.1558 - accuracy: 0.8047 - val_loss: 1.3045 - val_accuracy: 0.7940\n", + "Epoch 99/100\n", + "64/64 [==============================] - 60s 936ms/step - loss: 1.1505 - accuracy: 0.8055 - val_loss: 1.2926 - val_accuracy: 0.7955\n", + "Epoch 100/100\n", + "64/64 [==============================] - 60s 935ms/step 
- loss: 1.1460 - accuracy: 0.8059 - val_loss: 1.2775 - val_accuracy: 0.7976\n" + ] + } + ], + "source": [ + "#create callback monitoring the training loss ('loss', not 'val_loss') with patience 3\n", + "#stop training if loss does not improve in 3 straight epochs\n", + "callback = EarlyStopping(monitor='loss', patience=3)\n", + "\n", + "#fit the model\n", + "history = model.fit(train_eng_encoded, train_fr_encoded.reshape(train_fr_encoded.shape[0], train_fr_encoded.shape[1], 1), \n", + "                    epochs=100, batch_size=512, callbacks=[callback], validation_split = 0.1, verbose=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "fbb6f265", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#save the model\n", + "model.save('seq2seq.h5')\n", + "\n", + "#plot the loss history over the epochs\n", + "plt.plot(history.history['loss'])\n", + "plt.plot(history.history['val_loss'])\n", + "plt.legend(['train','validation'])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "57f64bf7", + "metadata": {}, + "outputs": [], + "source": [ + "#Create a dictionary from word indices to words in the French vocabulary\n", + "fr_idx2word = {index: word for word, index in fr_tokenizer.word_index.items()}\n", + "\n", + "#Define a function that generates prediction text from prediction tokenized sequences \n", + "def generate_prediction_texts(pred_seqs, idx2word_map):\n", + "    \"\"\"Convert predicted word-index sequences into space-joined sentences.\n", + "\n", + "    pred_seqs: iterable of sequences of predicted word indices, one per sentence.\n", + "    idx2word_map: dict mapping a word index to its word.\n", + "\n", + "    An index missing from idx2word_map, or a word identical to the one looked up\n", + "    at the previous position, contributes an empty string, so the joined text can\n", + "    contain runs of spaces. Returns a list with one string per input sequence.\n", + "    \"\"\"\n", + "    pred_texts = []\n", + "    for prediction in pred_seqs:\n", + "        #words kept for this sentence ('' marks a suppressed position)\n", + "        translated = []\n", + "        #lookup result of the previous position (None before the first one)\n", + "        previous_word = None\n", + "        for token in prediction:\n", + "            word_pred = idx2word_map.get(token)\n", + "            #suppress unknown indices and immediate repetitions of the same word\n", + "            if word_pred is None or word_pred == previous_word:\n", + "                translated.append('')\n", + "            else:\n", + "                translated.append(word_pred)\n", + "            previous_word = word_pred\n", + "        pred_texts.append(' '.join(translated))\n", +
"    return pred_texts" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d85a1638", + "metadata": {}, + "outputs": [], + "source": [ + "#get the model\n", + "model=load_model('./seq2seq.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2b2eb44b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "125/125 [==============================] - 5s 41ms/step\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetargetpredicted
13787don t talk to mene me parle pasne me pas
5128he is awesomeil est génialil est
9125they quarreledelles se sont disputéeselles ont
35041i have a bad sunburnj ai un mauvais coup de soleilj ai un
\n", + "
" ], + "text/plain": [ + " source target \\\n", + "13787 don t talk to me ne me parle pas \n", + "5128 he is awesome il est génial \n", + "9125 they quarreled elles se sont disputées \n", + "35041 i have a bad sunburn j ai un mauvais coup de soleil \n", + "\n", + " predicted \n", + "13787 ne me pas \n", + "5128 il est \n", + "9125 elles ont \n", + "35041 j ai un " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Predict the target sequences for the padded English test data (test_eng_encoded)\n", + "pred_probs = model.predict(test_eng_encoded, verbose=1)\n", + "preds = np.argmax(pred_probs, axis=-1)\n", + "\n", + "#convert the prediction from indices to text using the function defined before\n", + "pred_texts = generate_prediction_texts(pred_seqs=preds, idx2word_map=fr_idx2word)\n", + "#plot the original english sentence, the target french and the predicted french\n", + "pred_df = pd.DataFrame({'source': test_eng, 'target' : test_fr, 'predicted' : pred_texts})\n", + "pred_df.head(50)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4a8699fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetargetpredicted
13787don t talk to mene me parle pasne me pas
5128he is awesomeil est génialil est
9125they quarreledelles se sont disputéeselles ont
35041i have a bad sunburnj ai un mauvais coup de soleilj ai un
3510i am workingje suis en train de travaillerje suis
\n", + "
" + ], + "text/plain": [ + " source target \\\n", + "13787 don t talk to me ne me parle pas \n", + "5128 he is awesome il est génial \n", + "9125 they quarreled elles se sont disputées \n", + "35041 i have a bad sunburn j ai un mauvais coup de soleil \n", + "3510 i am working je suis en train de travailler \n", + "\n", + " predicted \n", + "13787 ne me pas \n", + "5128 il est \n", + "9125 elles ont \n", + "35041 j ai un \n", + "3510 je suis " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred_df[11:16]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "baca08b4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\lucas\\anaconda3\\lib\\site-packages\\nltk\\translate\\bleu_score.py:552: UserWarning: \n", + "The hypothesis contains 0 counts of 2-gram overlaps.\n", + "Therefore the BLEU score evaluates to 0, independently of\n", + "how many N-gram overlaps of lower order it contains.\n", + "Consider using lower n-gram order or use SmoothingFunction()\n", + " warnings.warn(_msg)\n", + "C:\\Users\\lucas\\anaconda3\\lib\\site-packages\\nltk\\translate\\bleu_score.py:552: UserWarning: \n", + "The hypothesis contains 0 counts of 3-gram overlaps.\n", + "Therefore the BLEU score evaluates to 0, independently of\n", + "how many N-gram overlaps of lower order it contains.\n", + "Consider using lower n-gram order or use SmoothingFunction()\n", + " warnings.warn(_msg)\n", + "C:\\Users\\lucas\\anaconda3\\lib\\site-packages\\nltk\\translate\\bleu_score.py:552: UserWarning: \n", + "The hypothesis contains 0 counts of 4-gram overlaps.\n", + "Therefore the BLEU score evaluates to 0, independently of\n", + "how many N-gram overlaps of lower order it contains.\n", + "Consider using lower n-gram order or use SmoothingFunction()\n", + " warnings.warn(_msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average BLEU 
score: 0.0319\n", + "Corpus-level BLEU score: 0.1193\n" + ] + } + ], + "source": [ + "#Tokenize the test sentences and predicted translations\n", + "ref_tokenized = [[word_tokenize(sentence.lower())] for sentence in test_fr]\n", + "pred_tokenized = [word_tokenize(sentence.lower()) for sentence in pred_texts]\n", + "\n", + "#Compute smoothed sentence-level BLEU: without smoothing a short hypothesis with no 2/3/4-gram match scores 0\n", + "from nltk.translate.bleu_score import SmoothingFunction\n", + "smoothing = SmoothingFunction().method1\n", + "bleu_scores = [sentence_bleu(refs, hyp, smoothing_function=smoothing) for refs, hyp in zip(ref_tokenized, pred_tokenized)]\n", + "\n", + "#Compute corpus-level BLEU score\n", + "corpus_bleu_score = corpus_bleu(ref_tokenized, pred_tokenized)\n", + "\n", + "print(f\"Average BLEU score: {np.mean(bleu_scores):.4f}\")\n", + "print(f\"Corpus-level BLEU score: {corpus_bleu_score:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb554898", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}