From 5f9812fd4546a36766f263d4733b34fc3a029c15 Mon Sep 17 00:00:00 2001 From: Brian-Weloba Date: Tue, 1 Apr 2025 20:42:24 +0100 Subject: [PATCH 1/7] Implement event deletion functionality; add DeleteEvent servlet for handling event deletions, enhance event listing with delete confirmation modal, and improve form validation in create.jsp and update.jsp --- .../java/dev/brianweloba/dao/EventDAO.java | 19 ++- .../dev/brianweloba/servlet/DeleteEvent.java | 76 +++++++++- .../dev/brianweloba/servlet/EventServlet.java | 13 ++ .../dev/brianweloba/servlet/UpdateEvent.java | 9 +- .../src/main/webapp/WEB-INF/views/create.jsp | 6 +- .../src/main/webapp/WEB-INF/views/events.jsp | 117 ++++++++++++++-- .../main/webapp/WEB-INF/views/eventsBody.jsp | 2 - .../src/main/webapp/WEB-INF/views/update.jsp | 130 ++++++++++++++++-- emmas/src/main/webapp/css/styles.css | 90 ++++++++++++ 9 files changed, 423 insertions(+), 39 deletions(-) diff --git a/emmas/src/main/java/dev/brianweloba/dao/EventDAO.java b/emmas/src/main/java/dev/brianweloba/dao/EventDAO.java index 2713063..e282ff3 100644 --- a/emmas/src/main/java/dev/brianweloba/dao/EventDAO.java +++ b/emmas/src/main/java/dev/brianweloba/dao/EventDAO.java @@ -108,6 +108,7 @@ public void update(Event event, String token) { manager.getTransaction().begin(); eventToUpdate.setTitle(event.getTitle()); + eventToUpdate.setEventHost(event.getEventHost()); eventToUpdate.setEventType(event.getEventType()); eventToUpdate.setEventLocation(event.getEventLocation()); eventToUpdate.setDescription(event.getDescription()); @@ -126,15 +127,25 @@ public void update(Event event, String token) { } } - public void delete(Long id) { + public boolean delete(Long id, String token) { EntityManager manager = HibernateUtil.getEntityManager(); try { manager.getTransaction().begin(); - Event event = manager.find(Event.class, id); - if (event != null) { - manager.remove(event); + Event event = manager.createQuery( + "SELECT e FROM Event e WHERE e.id = :id AND e.editToken = :token", Event.class) + .setParameter("id", id) + .setParameter("token", token) + .getResultList() + .stream() + .findFirst() + .orElse(null); + + if (event == null) { + return false; } + manager.remove(event); manager.getTransaction().commit(); + return true; } catch (Exception e) { if (manager.getTransaction().isActive()) { manager.getTransaction().rollback(); diff --git a/emmas/src/main/java/dev/brianweloba/servlet/DeleteEvent.java b/emmas/src/main/java/dev/brianweloba/servlet/DeleteEvent.java index cb9d9f9..86f1d20 100644 --- a/emmas/src/main/java/dev/brianweloba/servlet/DeleteEvent.java +++ b/emmas/src/main/java/dev/brianweloba/servlet/DeleteEvent.java @@ -1,4 +1,76 @@ package dev.brianweloba.servlet; -public class DeleteEvent { -} +import dev.brianweloba.dao.EventDAO; +import dev.brianweloba.model.Event; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +@WebServlet("/events/delete") +public class DeleteEvent extends HttpServlet { + private final EventDAO eventDAO = new EventDAO(); + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + try { + String idParam = request.getParameter("id"); + if (idParam == null || idParam.trim().isEmpty()) { + response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Event ID is required"); + return; + } + + Long eventId = Long.parseLong(idParam); + + String token = getEditTokenFromCookies(request, eventId); + if (token == null) { + request.setAttribute("error", "Invalid token. You don't have permission to delete this event."); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + return; + } + + Event event = eventDAO.findByIdAndToken(eventId, token); + if (event == null) { + request.setAttribute("error", "Event not found or you don't have permission to delete this event."); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + return; + } + + boolean success = eventDAO.delete(eventId, token); + + if (success) { + request.getSession().setAttribute("statusMessage", "Event deleted successfully!"); + response.sendRedirect(request.getContextPath() + "/events"); + } else { + request.setAttribute("error", "Failed to delete event."); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + } + + } catch (NumberFormatException e) { + request.setAttribute("error", "Invalid event ID format"); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + } catch (Exception e) { + e.printStackTrace(); + request.setAttribute("error", "Error deleting event: " + e.getMessage()); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + } + } + + public static String getEditTokenFromCookies(HttpServletRequest request, Long eventId) { + Cookie[] cookies = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if (cookie.getName().equals("event_" + eventId + "_token")) { + return cookie.getValue(); + } + } + } + return null; + } +} \ No newline at end of file diff --git a/emmas/src/main/java/dev/brianweloba/servlet/EventServlet.java b/emmas/src/main/java/dev/brianweloba/servlet/EventServlet.java index ac6e0e2..0972a82 100644 --- a/emmas/src/main/java/dev/brianweloba/servlet/EventServlet.java +++ b/emmas/src/main/java/dev/brianweloba/servlet/EventServlet.java @@ -9,7 +9,11 @@ import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; + +import static dev.brianweloba.servlet.UpdateEvent.getEditTokenFromCookies; @WebServlet("/events") public class EventServlet extends HttpServlet { @@ -38,6 +42,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) List events = eventDAO.findPaginated(startIndex, pageSize); + Map editTokens = new HashMap<>(); + for (Event event : events) { + String token = getEditTokenFromCookies(request, event.getId()); + if (token != null) { + editTokens.put(event.getId(), token); + } + } + request.setAttribute("editTokens", editTokens); + request.setAttribute("events", events); request.setAttribute("currentPage", currentPage); request.setAttribute("pageSize", pageSize); diff --git a/emmas/src/main/java/dev/brianweloba/servlet/UpdateEvent.java b/emmas/src/main/java/dev/brianweloba/servlet/UpdateEvent.java index d173657..64eafd3 100644 --- a/emmas/src/main/java/dev/brianweloba/servlet/UpdateEvent.java +++ b/emmas/src/main/java/dev/brianweloba/servlet/UpdateEvent.java @@ -59,12 +59,19 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) if(ValidationUtil.validateEvent(request,response,event,"/WEB-INF/views/update.jsp")) return; + event.setEventType(request.getParameter("eventType").toUpperCase()); + + if(event.getEventType().isEmpty()){ + request.setAttribute("eventTypeError", "Event Type is empty"); + request.getRequestDispatcher("/WEB-INF/views/prediction.jsp").forward(request, response); + return; + } eventDAO.update(event, token); response.sendRedirect(request.getContextPath() + "/events"); } - private String getEditTokenFromCookies(HttpServletRequest request, Long eventId) { + public static String getEditTokenFromCookies(HttpServletRequest request, Long eventId) { Cookie[] cookies = request.getCookies(); if (cookies != null) { for (Cookie cookie : cookies) { diff --git a/emmas/src/main/webapp/WEB-INF/views/create.jsp b/emmas/src/main/webapp/WEB-INF/views/create.jsp index 6253659..b3cc864 100644 --- a/emmas/src/main/webapp/WEB-INF/views/create.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/create.jsp @@ -9,7 +9,7 @@ let isValid = true; const title = document.getElementById('title').value; const eventHost = document.getElementById('eventHost').value; - const eventLocation = document.getElementById('location').value; + const eventLocation = document.getElementById('eventLocation').value; const eventCapacity = document.getElementById('eventCapacity').value; const eventDate = document.getElementById('eventDate').value; @@ -134,8 +134,8 @@
- - Location + ${requestScope.locationError} diff --git a/emmas/src/main/webapp/WEB-INF/views/events.jsp b/emmas/src/main/webapp/WEB-INF/views/events.jsp index cf02c98..aadbef5 100644 --- a/emmas/src/main/webapp/WEB-INF/views/events.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/events.jsp @@ -102,6 +102,51 @@ } }); } + + // For delete confirmation modal + function showDeleteConfirmation(eventId) { + const modal = document.getElementById('deleteModal'); + const confirmButton = document.getElementById('confirmDelete'); + const basePath = "${pageContext.request.contextPath}"; + // Set the correct delete URL + confirmButton.href = basePath + "/events/delete?id=" + eventId; + + // Show the modal + modal.classList.remove('hidden'); + } + + function closeDeleteModal() { + document.getElementById('deleteModal').classList.add('hidden'); + } + + // Toast notification function + function showToast(message) { + const toast = document.getElementById('toast'); + const toastMessage = document.getElementById('toastMessage'); + + toastMessage.textContent = message; + + // Show the toast (both transform and opacity) + toast.classList.remove('translate-y-full'); + toast.classList.remove('opacity-0'); + toast.classList.remove('pointer-events-none'); + + // Hide the toast after 3 seconds + setTimeout(function () { + toast.classList.add('translate-y-full'); + toast.classList.add('opacity-0'); + toast.classList.add('pointer-events-none'); + }, 3000); + } + + document.addEventListener('DOMContentLoaded', function () { + const statusMessage = "${sessionScope.statusMessage}"; + + if (statusMessage && statusMessage.trim() !== "") { + showToast(statusMessage); + <% session.removeAttribute("statusMessage"); %> + } + }); @@ -120,7 +165,8 @@
- + + +
+ - +
+
@@ -314,6 +376,33 @@ <%@ include file="../components/footer.jsp" %> + + + + +
+
+ + + + Event deleted successfully! +
+
diff --git a/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp b/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp index b4dbecd..9e28c19 100644 --- a/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp @@ -161,8 +161,6 @@ function editEvent() { const basePath = "${pageContext.request.contextPath}"; - console.log("basePath: ", basePath); - console.log("full path: ", basePath + "=events/update?id=" + _id); window.location.href = basePath + "/events/update?id=" + _id; } diff --git a/emmas/src/main/webapp/WEB-INF/views/update.jsp b/emmas/src/main/webapp/WEB-INF/views/update.jsp index 82475ba..37f6be4 100644 --- a/emmas/src/main/webapp/WEB-INF/views/update.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/update.jsp @@ -1,8 +1,99 @@ +<%@ taglib prefix="fmt" uri="jakarta.tags.fmt" %> <%@ page contentType="text/html;charset=UTF-8" %> Add Event + <%@ include file="../components/navbar.jsp" %> @@ -14,28 +105,36 @@
+ class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline ${not empty requestScope.titleError ? 'border-red-500' : ''}"> + + ${requestScope.titleError} +
+ class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline ${not empty requestScope.hostError ? 'border-red-500' : ''}"> + + ${requestScope.hostError} +
- -
+ + + ${requestScope.eventTypeError} +
- + + + + ${requestScope.dateError} +
@@ -47,9 +146,11 @@
+ class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline ${not empty requestScope.locationError ? 'border-red-500' : ''}"> + + ${requestScope.locationError} +
-
@@ -57,6 +158,9 @@ class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"> ${requestScope.event!=null ? requestScope.event.description : ''} + + ${requestScope.descriptionError} +
diff --git a/emmas/src/main/webapp/WEB-INF/views/events.jsp b/emmas/src/main/webapp/WEB-INF/views/events.jsp index aadbef5..32a501e 100644 --- a/emmas/src/main/webapp/WEB-INF/views/events.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/events.jsp @@ -25,7 +25,7 @@ function changePageSize(size) { const url = new URL(window.location.href); - url.searchParams.set('page', 1); + url.searchParams.set('page', '1'); url.searchParams.set('size', size); window.location.href = url.toString(); } @@ -161,21 +161,6 @@
Showing ${requestScope.startIndex +1} to ${requestScope.endIndex} of ${requestScope.totalEvents} entries
- -
-
- -
- -
+ +
+ + + +
`, + ``, + ``, + ``, + ``, + `${event.eventCapacity}`, + ``, + `` + )"> + +
+ ${event.eventType} +
+ +
+ +
+

+ +

+
+ +
+
+ + +
+ +
+ + + +
+ Date: + +
+
+ + +
+ + + +
+ Location: + +
+
+ + +
+ + + +
+ Attendance: + + + + + + + ${totalGuests}/${event.eventCapacity} + (${rsvpCount} RSVPs) + +
+
+
+ + +
+

Description:

+

+ +

+
+ + +
+
+ Hosted by: + ${event.eventHost} +
+ +
+
+
+
+
+ +
+

No events available at this time

+
+
+
+
+ -
@@ -293,7 +273,6 @@ let _id; function openModal(id, title, date, description, location, capacity, host, eventType) { - // Escape HTML to prevent XSS const escapeHtml = (unsafe) => { return unsafe .replace(/&/g, "&") @@ -316,7 +295,6 @@ document.getElementById("eventModal").classList.remove("hidden"); - // Add event listener for escape key document.addEventListener('keydown', handleEscKey); } @@ -328,11 +306,9 @@ function closeModal() { document.getElementById("eventModal").classList.add("hidden"); - // Remove escape key listener when modal is closed document.removeEventListener('keydown', handleEscKey); } - // Close modal when clicking outside document.addEventListener('DOMContentLoaded', () => { const modal = document.getElementById('eventModal'); modal.addEventListener('click', (e) => { @@ -347,19 +323,16 @@ window.location.href = basePath + "/events/update?id=" + _id; } - // Toast notification function function showToast(message) { const toast = document.getElementById('toast'); const toastMessage = document.getElementById('toastMessage'); toastMessage.textContent = message; - // Show the toast (both transform and opacity) toast.classList.remove('translate-y-full'); toast.classList.remove('opacity-0'); toast.classList.remove('pointer-events-none'); - // Hide the toast after 3 seconds setTimeout(function () { toast.classList.add('translate-y-full'); toast.classList.add('opacity-0'); @@ -368,7 +341,6 @@ } function filterEventsByType(type) { - // Update active button state document.querySelectorAll('.filter-btn').forEach(btn => { if ((type === 'all' && !btn.dataset.type) || btn.dataset.type === type) { btn.classList.add('active-filter', 'bg-cyan-950', 'text-white'); diff --git a/emmas/src/main/webapp/WEB-INF/views/update.jsp b/emmas/src/main/webapp/WEB-INF/views/update.jsp index 37f6be4..7c769d4 100644 --- a/emmas/src/main/webapp/WEB-INF/views/update.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/update.jsp @@ -122,7 +122,7 @@
- ${requestScope.eventTypeError} From 681175365126787e41f5f556a87da6900bd212fc Mon Sep 17 00:00:00 2001 From: Brian-Weloba Date: Wed, 2 Apr 2025 18:39:24 +0100 Subject: [PATCH 4/7] Enhance event classification; improve EventClassifierTester with balanced evaluation, detailed metrics logging, and cross-validation. Refactor EventTypeClassifier for better data handling and model training, including advanced class balancing techniques and enhanced text processing. --- .../weka/EventClassifierTester.java | 223 +++++- .../brianweloba/weka/EventTypeClassifier.java | 712 +++++++----------- .../weka/ImprovedEventClassifierTester.java | 319 -------- .../java/dev/brianweloba/weka/RunTest.java | 5 +- 4 files changed, 479 insertions(+), 780 deletions(-) delete mode 100644 emmas/src/main/java/dev/brianweloba/weka/ImprovedEventClassifierTester.java diff --git a/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java b/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java index 3393556..11ec270 100644 --- a/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java +++ b/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java @@ -1,15 +1,14 @@ package dev.brianweloba.weka; import dev.brianweloba.model.Event; +import weka.classifiers.Evaluation; import weka.core.Instances; import weka.core.converters.ArffSaver; + import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintWriter; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; public class EventClassifierTester { private final EventTypeClassifier classifier; @@ -22,8 +21,30 @@ public EventClassifierTester(String outputDir) { } public void runFullTest(List allEvents) throws Exception { - // Split data (80% train, 20% test) - int splitPoint = (int) (allEvents.size() * 0.8); + ClassifierLogger.log("Starting improved testing process with balanced evaluation"); + + // Ensure we have enough data + if (allEvents.size() < 50) { + ClassifierLogger.log("Warning: Small dataset size (" + allEvents.size() + " events)"); + } + + // Check class distribution before processing + Map typeCounts = new HashMap<>(); + for (Event event : allEvents) { + if (event.getEventType() != null) { + typeCounts.merge(event.getEventType(), 1, Integer::sum); + } + } + + // Log distribution + ClassifierLogger.log("Class distribution in full dataset:"); + typeCounts.forEach((type, count) -> + ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", + type, count, (double)count * 100 / allEvents.size()))); + + // Split data (70% train, 30% test for better test set size) + Collections.shuffle(allEvents, new Random(42)); // Shuffle with fixed seed for reproducibility + int splitPoint = (int) (allEvents.size() * 0.7); List trainingEvents = allEvents.subList(0, splitPoint); List testEvents = allEvents.subList(splitPoint, allEvents.size()); @@ -35,7 +56,12 @@ public void runFullTest(List allEvents) throws Exception { // Test and save results ClassifierLogger.log("Starting testing with " + testEvents.size() + " events"); TestResults results = evaluateOnTestSet(testEvents); - saveTestResults(results, testEvents); + + // Calculate per-class metrics + Map classMetrics = calculatePerClassMetrics(testEvents); + + // Save detailed results + saveTestResults(results, testEvents, classMetrics); // Log metrics ClassifierLogger.logMetrics( @@ -43,27 +69,138 @@ public void runFullTest(List allEvents) throws Exception { results.recall, results.f1, trainingEvents.size(), testEvents.size() ); + + // Perform cross-validation for more robust evaluation + performCrossValidation(allEvents,10); + } + + // In the performCrossValidation method, create a new classifier for each fold + private void performCrossValidation(List allEvents, int folds) { + try { + // Generate a stratified sample of the data + Random rand = new Random(42); + Collections.shuffle(allEvents, rand); + + // Create CV partitions manually + int foldSize = allEvents.size() / folds; + double totalAccuracy = 0; + + for (int i = 0; i < folds; i++) { + // Split into training and testing for this fold + List cvTest = new ArrayList<>(allEvents.subList(i * foldSize, (i + 1) * foldSize)); + List cvTrain = new ArrayList<>(allEvents); + cvTrain.removeAll(cvTest); + + // Create a fresh classifier for this fold + EventTypeClassifier cvClassifier = new EventTypeClassifier(); + + // Train it on this fold's training data + cvClassifier.trainModel(cvTrain); + + // Test on this fold's test data + int correct = 0; + for (Event event : cvTest) { + String predicted = cvClassifier.predictEventType(event); + if (predicted.equals(event.getEventType())) { + correct++; + } + } + + double foldAccuracy = (double) correct / cvTest.size(); + totalAccuracy += foldAccuracy; + + ClassifierLogger.log(String.format("Fold %d accuracy: %.2f%%", i+1, foldAccuracy * 100)); + } + + double avgAccuracy = totalAccuracy / folds; + ClassifierLogger.log(String.format("Cross-validation average accuracy: %.2f%%", avgAccuracy * 100)); + + } catch (Exception e) { + ClassifierLogger.log("Error during cross-validation: " + e.getMessage()); + e.printStackTrace(); + } + } + + private void saveConfusionMatrix(Evaluation eval, Instances data) { + try (PrintWriter writer = new PrintWriter(new File(outputDir + "/confusion_matrix.txt"))) { + writer.println("=== Confusion Matrix ==="); + writer.println(eval.toMatrixString()); + + writer.println("\n=== Detailed Accuracy By Class ==="); + writer.println(eval.toClassDetailsString()); + } catch (Exception e) { + ClassifierLogger.log("Error saving confusion matrix: " + e.getMessage()); + } + } + + private Map calculatePerClassMetrics(List testEvents) { + Map metrics = new HashMap<>(); + Map truePositives = new HashMap<>(); + Map falsePositives = new HashMap<>(); + Map falseNegatives = new HashMap<>(); + Map totalActual = new HashMap<>(); + + // Count actual class distribution + for (Event event : testEvents) { + String actual = event.getEventType(); + totalActual.merge(actual, 1, Integer::sum); + } + + // Calculate metrics + for (Event event : testEvents) { + String actual = event.getEventType(); + String predicted = classifier.predictEventType(event); + + if (actual.equals(predicted)) { + truePositives.merge(actual, 1, Integer::sum); + } else { + falseNegatives.merge(actual, 1, Integer::sum); + falsePositives.merge(predicted, 1, Integer::sum); + } + } + + // Calculate precision, recall, and F1 for each class + for (String className : totalActual.keySet()) { + int tp = truePositives.getOrDefault(className, 0); + int fp = falsePositives.getOrDefault(className, 0); + int fn = falseNegatives.getOrDefault(className, 0); + int total = totalActual.get(className); + + double precision = (tp + fp == 0) ? 0 : (double) tp / (tp + fp); + double recall = (double) tp / total; + double f1 = (precision + recall == 0) ? 0 : 2 * precision * recall / (precision + recall); + + metrics.put(className, new ClassMetrics(precision, recall, f1, total)); + } + + return metrics; } private TestResults evaluateOnTestSet(List testEvents) { TestResults results = new TestResults(); int correct = 0; + Map classCounts = new HashMap<>(); + Map correctCounts = new HashMap<>(); try (PrintWriter writer = new PrintWriter(new File(outputDir + "/detailed_results.csv"))) { - writer.println("actual,predicted,confidence,title"); + writer.println("actual,predicted,confidence,title,is_correct"); for (Event event : testEvents) { String actual = event.getEventType(); String predicted = classifier.predictEventType(event); Map probs = classifier.getPredictionProbabilities(event); double confidence = probs.getOrDefault(predicted, 0.0); + boolean isCorrect = actual.equals(predicted); - writer.printf("%s,%s,%.4f,%s%n", + writer.printf("%s,%s,%.4f,%s,%b%n", actual, predicted, confidence, - event.getTitle().replace(",", "")); + event.getTitle().replace(",", ""), isCorrect); - if (actual.equals(predicted)) { + // Update class statistics + classCounts.merge(actual, 1, Integer::sum); + if (isCorrect) { correct++; + correctCounts.merge(actual, 1, Integer::sum); } else { ClassifierLogger.log(String.format( "MISCLASSIFIED: Actual=%s, Predicted=%s (%.2f%%) for '%s'", @@ -72,9 +209,35 @@ private TestResults evaluateOnTestSet(List testEvents) { } } + // Calculate global metrics results.accuracy = (double) correct / testEvents.size(); - // Calculate other metrics (precision, recall, f1) here - // You'd need a confusion matrix for multi-class metrics + + // Calculate confusion matrix-based metrics + double totalPrecision = 0; + double totalRecall = 0; + int classCount = classCounts.size(); + + for (String className : classCounts.keySet()) { + int truePositives = correctCounts.getOrDefault(className, 0); + int total = classCounts.get(className); + + // Calculate predicted count (for precision) + long predictedCount = testEvents.stream() + .filter(e -> classifier.predictEventType(e).equals(className)) + .count(); + + double precision = predictedCount == 0 ? 0 : (double) truePositives / predictedCount; + double recall = (double) truePositives / total; + + totalPrecision += precision; + totalRecall += recall; + } + + // Calculate macro-averaged metrics + results.precision = totalPrecision / classCount; + results.recall = totalRecall / classCount; + results.f1 = (results.precision + results.recall == 0) ? 0 : + 2 * results.precision * results.recall / (results.precision + results.recall); } catch (Exception e) { ClassifierLogger.log("Error during testing: " + e.getMessage()); @@ -95,13 +258,24 @@ private void saveModelAndData(List events, String prefix) throws Exceptio classifier.saveModel(outputDir + "/" + prefix + "_model.model"); } - private void saveTestResults(TestResults results, List testEvents) { + private void saveTestResults(TestResults results, List testEvents, + Map classMetrics) { try (PrintWriter writer = new PrintWriter(new File(outputDir + "/summary_results.txt"))) { writer.println("=== Classifier Test Results ==="); writer.printf("Accuracy: %.2f%%%n", results.accuracy * 100); - writer.printf("Precision: %.4f%n", results.precision); - writer.printf("Recall: %.4f%n", results.recall); - writer.printf("F1 Score: %.4f%n", results.f1); + writer.printf("Macro-Averaged Precision: %.4f%n", results.precision); + writer.printf("Macro-Averaged Recall: %.4f%n", results.recall); + writer.printf("Macro-Averaged F1 Score: %.4f%n", results.f1); + + writer.println("\n=== Per-Class Performance ==="); + for (Map.Entry entry : classMetrics.entrySet()) { + ClassMetrics metrics = entry.getValue(); + writer.printf("%s (n=%d):%n", entry.getKey(), metrics.count); + writer.printf(" Precision: %.4f%n", metrics.precision); + writer.printf(" Recall: %.4f%n", metrics.recall); + writer.printf(" F1 Score: %.4f%n", metrics.f1); + } + writer.println("\n=== Test Set Distribution ==="); // Count actual class distribution @@ -115,6 +289,7 @@ private void saveTestResults(TestResults results, List testEvents) { entry.getKey(), entry.getValue(), (double) entry.getValue() / testEvents.size() * 100); } + } catch (IOException e) { ClassifierLogger.log("Error saving test results: " + e.getMessage()); } @@ -126,4 +301,18 @@ private static class TestResults { double recall; double f1; } + + private static class ClassMetrics { + double precision; + double recall; + double f1; + int count; + + public ClassMetrics(double precision, double recall, double f1, int count) { + this.precision = precision; + this.recall = recall; + this.f1 = f1; + this.count = count; + } + } } \ No newline at end of file diff --git a/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java b/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java index fa35e1a..25572d1 100644 --- a/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java +++ b/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java @@ -1,325 +1,119 @@ package dev.brianweloba.weka; import dev.brianweloba.model.Event; -import weka.classifiers.Classifier; -import weka.classifiers.trees.RandomForest; +import weka.classifiers.bayes.NaiveBayes; +import weka.classifiers.evaluation.Evaluation; import weka.core.*; -import weka.core.stemmers.SnowballStemmer; import weka.core.stopwords.Rainbow; +import weka.core.tokenizers.NGramTokenizer; import weka.filters.Filter; import weka.filters.supervised.instance.SMOTE; import weka.filters.unsupervised.attribute.StringToWordVector; -import weka.core.tokenizers.NGramTokenizer; -import weka.classifiers.bayes.NaiveBayes; import java.util.*; import java.util.stream.Collectors; public class EventTypeClassifier { -// private RandomForest classifier; + private static final int MIN_CLASSES_FOR_CLASSIFICATION = 2; + private static final int MIN_INSTANCES_FOR_BALANCING = 6; + private static final int SMOTE_NEIGHBORS = 5; + private static final int SMOTE_PERCENTAGE = 100; + private static final int NGRAM_MIN_SIZE = 1; + private static final int NGRAM_MAX_SIZE = 2; + private static final int WORDS_TO_KEEP = 500; + private static final long RANDOM_SEED = 42L; + private NaiveBayes classifier; private Instances trainingHeader; public void trainModel(List events) { try { - // Filter out events with null eventType - List validEvents = events.stream() - .filter(e -> e.getEventType() != null && !e.getEventType().isEmpty()) - .toList(); - - if (validEvents.isEmpty()) { - throw new IllegalArgumentException("No valid events with eventType available for training"); - } + List validEvents = validateAndFilterEvents(events); + logClassDistribution(validEvents); - ClassifierLogger.log("Starting training with " + validEvents.size() + " events"); + Instances trainingData = createTrainingDataset(validEvents); + Instances filteredData = applyTextProcessingFilters(trainingData); + Instances balancedData = balanceDataset(filteredData); - // Get unique event types for class attribute - Set uniqueEventTypes = validEvents.stream() - .map(Event::getEventType) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); + trainClassifier(balancedData); + evaluateModel(balancedData); - if (uniqueEventTypes.size() < 2) { - throw new IllegalArgumentException("Need at least 2 different event types for classification"); - } + ClassifierLogger.log("NaiveBayes model trained successfully with " + + balancedData.numInstances() + " balanced instances and " + + balancedData.numAttributes() + " features"); + } catch (Exception e) { + throw new RuntimeException("Failed to train model: " + e.getMessage(), e); + } + } - // Check class distribution before training - Map classCounts = validEvents.stream() - .collect(Collectors.groupingBy(Event::getEventType, Collectors.counting())); + private List validateAndFilterEvents(List events) { + List validEvents = events.stream() + .filter(e -> e.getEventType() != null && !e.getEventType().isEmpty()) + .toList(); - ClassifierLogger.log("Class distribution in training data:"); - classCounts.forEach((type, count) -> - ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", - type, count, (double)count * 100 / validEvents.size()))); + if (validEvents.isEmpty()) { + throw new IllegalArgumentException("No valid events with eventType available for training"); + } - ArrayList eventTypes = new ArrayList<>(uniqueEventTypes); - ArrayList attributes = createAttributes(); + ClassifierLogger.log("Starting training with " + validEvents.size() + " events"); - // Change the name of the class attribute to avoid conflict - Attribute classAttribute = new Attribute("class_event_type", eventTypes); - attributes.add(classAttribute); + Set uniqueEventTypes = validEvents.stream() + .map(Event::getEventType) + .filter(Objects::nonNull) + .collect(Collectors.toSet()); - Instances trainingData = new Instances("EventData", attributes, 0); - trainingData.setClassIndex(trainingData.numAttributes() - 1); + if (uniqueEventTypes.size() < MIN_CLASSES_FOR_CLASSIFICATION) { + throw new IllegalArgumentException("Need at least " + MIN_CLASSES_FOR_CLASSIFICATION + + " different event types for classification"); + } - // Add instances - for (Event event : validEvents) { - addInstance(trainingData, event); - } + return validEvents; + } - // Apply text processing filter - Instances filteredData = applyStringToWordVectorFilter(trainingData); - - // Apply advanced class balancing techniques - if (filteredData.numInstances() >= uniqueEventTypes.size() * 6) { - ClassifierLogger.log("Applying advanced class balancing techniques"); - try { - // First, determine the target number of instances per class - // We'll aim for more balance than the current approach - int maxClassSize = 0; - Map classInstanceCounts = new HashMap<>(); - - for (int i = 0; i < filteredData.numInstances(); i++) { - double classValue = filteredData.instance(i).classValue(); - classInstanceCounts.merge(classValue, 1, Integer::sum); - maxClassSize = Math.max(maxClassSize, classInstanceCounts.get(classValue)); - } - - // Create balanced dataset with custom sampling - Instances balancedData = new Instances(filteredData, 0); - - for (double classValue : classInstanceCounts.keySet()) { - // Create a separate dataset for each class - Instances classData = new Instances(filteredData, 0); - for (int i = 0; i < filteredData.numInstances(); i++) { - if (filteredData.instance(i).classValue() == classValue) { - classData.add(filteredData.instance(i)); - } - } - - // For the majority class (usually MEETUP), undersample to reduce bias - if (classInstanceCounts.get(classValue) > maxClassSize / 2) { - int targetSize = maxClassSize / 2; - Random rand = new Random(42); - - // Randomly sample instances - classData.randomize(rand); - if (classData.numInstances() > targetSize) { - Instances sampledData = new Instances(classData, 0, targetSize); - for (int i = 0; i < sampledData.numInstances(); i++) { - balancedData.add(sampledData.instance(i)); - } - } else { - // If we don't have enough instances, use all of them - for (int i = 0; i < classData.numInstances(); i++) { - balancedData.add(classData.instance(i)); - } - } - } - // For minority classes, apply SMOTE more aggressively - else { - SMOTE smote = new SMOTE(); - smote.setInputFormat(classData); - // Use fewer neighbors for small classes - int neighbors = Math.min(5, classData.numInstances() - 1); - if (neighbors < 1) neighbors = 1; - smote.setNearestNeighbors(neighbors); - - // Increase percentage for more aggressive oversampling - // Calculate percentage to reach about equal representation - int currentSize = classInstanceCounts.get(classValue); - int targetSize = maxClassSize / 2; - int percentage = (targetSize * 100) / currentSize - 100; - smote.setPercentage(Math.max(100, percentage)); - - Instances smoteData = Filter.useFilter(classData, smote); - for (int i = 0; i < smoteData.numInstances(); i++) { - balancedData.add(smoteData.instance(i)); - } - } - } - - // Use the balanced dataset - filteredData = balancedData; - - // Log new balanced class distribution - Map balancedCounts = new HashMap<>(); - for (int i = 0; i < filteredData.numInstances(); i++) { - double classValue = filteredData.instance(i).classValue(); - balancedCounts.merge(classValue, 1, Integer::sum); - } - - ClassifierLogger.log("Class distribution after advanced balancing:"); - Instances finalFilteredData = filteredData; - balancedCounts.forEach((classVal, count) -> { - String className = finalFilteredData.classAttribute().value(classVal.intValue()); - ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", - className, count, (double)count * 100 / finalFilteredData.numInstances())); - }); - - } catch (Exception e) { - ClassifierLogger.log("Advanced balancing failed: " + e.getMessage()); - ClassifierLogger.log("Falling back to original SMOTE approach"); - - // Fall back to original SMOTE approach - try { - SMOTE smote = new SMOTE(); - smote.setInputFormat(filteredData); - smote.setNearestNeighbors(5); - smote.setPercentage(100); // Create instances to match majority class - - // Apply SMOTE filter - Instances balancedData = Filter.useFilter(filteredData, smote); - - // Log new distribution - Map balancedCounts = new HashMap<>(); - for (int i = 0; i < balancedData.numInstances(); i++) { - double classValue = balancedData.instance(i).classValue(); - balancedCounts.merge(classValue, 1, Integer::sum); - } - - ClassifierLogger.log("Class distribution after basic SMOTE:"); - balancedCounts.forEach((classVal, count) -> { - String className = balancedData.classAttribute().value(classVal.intValue()); - ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", - className, count, (double)count * 100 / balancedData.numInstances())); - }); - - // Use the balanced dataset for training - filteredData = balancedData; - } catch (Exception ex) { - ClassifierLogger.log("SMOTE also failed, using original data: " + ex.getMessage()); - } - } - } else { - ClassifierLogger.log("Not enough instances for balancing, using original data"); - } + private void logClassDistribution(List events) { + Map classCounts = events.stream() + .collect(Collectors.groupingBy(Event::getEventType, Collectors.counting())); - // Train classifier with improved configuration -// classifier = new RandomForest(); - classifier = new NaiveBayes(); + ClassifierLogger.log("Class distribution in training data:"); + classCounts.forEach((type, count) -> + ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", + type, count, (double) count * 100 / events.size()))); + } - // Build the classifier - classifier.buildClassifier(filteredData); - trainingHeader = new Instances(filteredData, 0); + private Instances createTrainingDataset(List events) { + Set uniqueEventTypes = events.stream() + .map(Event::getEventType) + .collect(Collectors.toSet()); - // Log out-of-bag error estimate -// ClassifierLogger.log(String.format("Out-of-bag error estimate: %.2f%%", -// classifier.measureOutOfBagError() * 100)); - ClassifierLogger.log("NaiveBayes model trained successfully with " + - filteredData.numInstances() + " balanced instances and " + - filteredData.numAttributes() + " features"); + ArrayList attributes = createAttributes(); + Attribute classAttribute = new Attribute("class_event_type", new ArrayList<>(uniqueEventTypes)); + attributes.add(classAttribute); - // For RandomForest in Weka, we can't directly get feature importance - // but we can try to analyze the model in other ways - ClassifierLogger.log("Model trained successfully with " + - filteredData.numInstances() + " balanced instances and " + - filteredData.numAttributes() + " features"); + Instances trainingData = new Instances("EventData", attributes, 0); + trainingData.setClassIndex(trainingData.numAttributes() - 1); - } catch (Exception e) { - throw new RuntimeException("Failed to train model: " + e.getMessage(), e); + for (Event event : events) { + addInstance(trainingData, event); } + + return trainingData; } -// private Instances applyStringToWordVectorFilter(Instances data) throws Exception { -// // Save original class index and attribute -// int originalClassIndex = data.classIndex(); -// Attribute originalClassAttr = data.attribute(originalClassIndex); -// -// // Create a copy of the dataset -// Instances dataForFiltering = new Instances(data); -// String tempClassName = "___temp_class_attribute___"; -// dataForFiltering.renameAttribute(originalClassIndex, tempClassName); -// dataForFiltering.setClassIndex(-1); -// -// // Configure filter to only process text attributes -// StringToWordVector filter = new StringToWordVector(); -// filter.setAttributeIndices("1,2,3,4"); // Only process title, description, hostName, location -// -// // Increase word count - this was likely too limiting before -// filter.setWordsToKeep(200); // Increased from 1000 -// filter.setPeriodicPruning(100); -// -// filter.setLowerCaseTokens(true); -// filter.setTFTransform(true); -// filter.setIDFTransform(true); -// -// // Use NGram tokenizer instead of directly setting n-gram parameters -// NGramTokenizer tokenizer = new NGramTokenizer(); -// tokenizer.setNGramMinSize(1); -// tokenizer.setNGramMaxSize(3); -// filter.setTokenizer(tokenizer); -// -// // Add stemming to combine related words -// filter.setStemmer(new weka.core.stemmers.LovinsStemmer()); -// -// // Set minimum term frequency to filter very rare terms -// filter.setMinTermFreq(2); -// -// // Enable output word counts - may be useful for debugging -// filter.setOutputWordCounts(true); -// -// // Configure stopwords to remove common words -// filter.setStopwordsHandler(new weka.core.stopwords.Rainbow()); -// -// filter.setInputFormat(dataForFiltering); -// -// // Apply filter -// Instances filteredData = Filter.useFilter(dataForFiltering, filter); -// -// // Add the original class attribute back -// List classValues = new ArrayList<>(); -// for (int i = 0; i < originalClassAttr.numValues(); i++) { -// classValues.add(originalClassAttr.value(i)); -// } -// Attribute newClassAttr = new Attribute(originalClassAttr.name(), classValues); -// -// filteredData.insertAttributeAt(newClassAttr, filteredData.numAttributes()); -// filteredData.setClassIndex(filteredData.numAttributes() - 1); -// -// for (int i = 0; i < data.numInstances(); i++) { -// double classValue = data.instance(i).value(originalClassIndex); -// filteredData.instance(i).setValue(filteredData.classIndex(), classValue); -// } -// -// return filteredData; -// } - - private Instances applyStringToWordVectorFilter(Instances data) throws Exception { - // Save class attribute + private Instances applyTextProcessingFilters(Instances data) throws Exception { int classIndex = data.classIndex(); Attribute classAttr = data.attribute(classIndex); - // Create copy without class for filtering Instances dataForFiltering = new Instances(data); dataForFiltering.setClassIndex(-1); dataForFiltering.deleteAttributeAt(classIndex); - // Configure enhanced text processing - StringToWordVector filter = new StringToWordVector(); - filter.setAttributeIndices("1,2,3,4"); // Text attributes - filter.setWordsToKeep(500); - filter.setTFTransform(true); - filter.setIDFTransform(true); - - // Use N-grams - NGramTokenizer tokenizer = new NGramTokenizer(); - tokenizer.setNGramMinSize(1); - tokenizer.setNGramMaxSize(2); // Bigrams help capture phrases like "coding workshop" - filter.setTokenizer(tokenizer); - - // Add stopwords and stemming - filter.setStopwordsHandler(new Rainbow()); - filter.setStemmer(new SnowballStemmer()); - + StringToWordVector filter = createTextFilter(); filter.setInputFormat(dataForFiltering); Instances filteredData = Filter.useFilter(dataForFiltering, filter); - // Add back class attribute filteredData.insertAttributeAt(classAttr, filteredData.numAttributes()); filteredData.setClassIndex(filteredData.numAttributes() - 1); - // Copy class values for (int i = 0; i < data.numInstances(); i++) { filteredData.instance(i).setValue(filteredData.classIndex(), data.instance(i).value(classIndex)); @@ -328,179 +122,220 @@ private Instances applyStringToWordVectorFilter(Instances data) throws Exception return filteredData; } - private ArrayList createAttributes() { - ArrayList attributes = new ArrayList<>(); + private StringToWordVector createTextFilter() { + StringToWordVector filter = new StringToWordVector(); + filter.setAttributeIndices("1,2,3,4"); // Text attributes + filter.setWordsToKeep(WORDS_TO_KEEP); + filter.setTFTransform(true); + filter.setIDFTransform(true); - // Original attributes - attributes.add(new Attribute("title", (List) null)); - attributes.add(new Attribute("description", (List) null)); - attributes.add(new Attribute("hostName", (List) null)); - attributes.add(new Attribute("location", (List) null)); - attributes.add(new Attribute("eventCapacity")); - attributes.add(new Attribute("eventMonth")); - attributes.add(new Attribute("eventDayOfWeek")); - attributes.add(new Attribute("isWeekend")); - attributes.add(new Attribute("titleLength")); - attributes.add(new Attribute("descriptionWordCount")); + NGramTokenizer tokenizer = new NGramTokenizer(); + tokenizer.setNGramMinSize(NGRAM_MIN_SIZE); + tokenizer.setNGramMaxSize(NGRAM_MAX_SIZE); + filter.setTokenizer(tokenizer); + filter.setStopwordsHandler(new Rainbow()); + filter.setStemmer(new weka.core.stemmers.LovinsStemmer()); - return attributes; + return filter; + } + + private Instances balanceDataset(Instances data) throws Exception { + Set uniqueClasses = new HashSet<>(); + for (int i = 0; i < data.numInstances(); i++) { + uniqueClasses.add(data.instance(i).classValue()); + } + + if (data.numInstances() < uniqueClasses.size() * MIN_INSTANCES_FOR_BALANCING) { + ClassifierLogger.log("Not enough instances for balancing, using original data"); + return data; + } + + try { + return applyAdvancedBalancing(data); + } catch (Exception e) { + ClassifierLogger.log("Advanced balancing failed: " + e.getMessage()); + ClassifierLogger.log("Falling back to basic SMOTE"); + return applyBasicSMOTE(data); + } } private Instances applyAdvancedBalancing(Instances data) throws Exception { - // Calculate class distribution - Map classCounts = new HashMap<>(); + int maxClassSize = 0; + Map classInstanceCounts = new HashMap<>(); + for (int i = 0; i < data.numInstances(); i++) { double classValue = data.instance(i).classValue(); - classCounts.merge(classValue, 1, Integer::sum); + classInstanceCounts.merge(classValue, 1, Integer::sum); + maxClassSize = Math.max(maxClassSize, classInstanceCounts.get(classValue)); } - // Find median class size as target - List counts = new ArrayList<>(classCounts.values()); - Collections.sort(counts); - int targetSize = counts.get(counts.size() / 2); // Median - Instances balancedData = new Instances(data, 0); - // Process each class separately - for (double classValue : classCounts.keySet()) { - Instances classData = new Instances(data, 0); - for (int i = 0; i < data.numInstances(); i++) { - if (data.instance(i).classValue() == classValue) { - classData.add(data.instance(i)); - } - } + for (double classValue : classInstanceCounts.keySet()) { + Instances classData = extractClassInstances(data, classValue); + int targetSize = maxClassSize / 2; - int currentSize = classCounts.get(classValue); - - if (currentSize > targetSize) { - // Undersample majority classes (like WORKSHOP) - classData.randomize(new Random(42)); - for (int i = 0; i < Math.min(targetSize, classData.numInstances()); i++) { - balancedData.add(classData.instance(i)); - } - } else if (currentSize < targetSize) { - // Oversample minority classes - SMOTE smote = new SMOTE(); - smote.setInputFormat(classData); - - // Calculate needed percentage - int percentage = (int) (((double)targetSize / currentSize - 1) * 100); - smote.setPercentage(Math.max(100, percentage)); - smote.setNearestNeighbors(Math.min(5, classData.numInstances() - 1)); - - Instances oversampled = Filter.useFilter(classData, smote); - for (int i = 0; i < oversampled.numInstances(); i++) { - balancedData.add(oversampled.instance(i)); - } + if (classInstanceCounts.get(classValue) > targetSize) { + balancedData.addAll(downsampleClass(classData, targetSize)); } else { - // Already balanced - for (int i = 0; i < classData.numInstances(); i++) { - balancedData.add(classData.instance(i)); - } + balancedData.addAll(oversampleClassWithSMOTE(classData, targetSize)); } } + logBalancedDistribution(balancedData); return balancedData; } - private void addInstance(Instances data, Event event) { - double[] values = new double[data.numAttributes()]; + private Instances extractClassInstances(Instances data, double classValue) { + Instances classData = new Instances(data, 0); + for (int i = 0; i < data.numInstances(); i++) { + if (data.instance(i).classValue() == classValue) { + classData.add(data.instance(i)); + } + } + return classData; + } - // Set original values (0-9) - values[0] = data.attribute(0).addStringValue(preprocessText(event.getTitle())); - values[1] = data.attribute(1).addStringValue(preprocessText(event.getDescription())); - values[2] = data.attribute(2).addStringValue(preprocessText(event.getEventHost())); - values[3] = data.attribute(3).addStringValue(preprocessText(event.getEventLocation())); - values[4] = event.getEventCapacity(); + private Instances downsampleClass(Instances classData, int targetSize) { + classData.randomize(new Random(RANDOM_SEED)); + return new Instances(classData, 0, Math.min(targetSize, classData.numInstances())); + } - Calendar cal = Calendar.getInstance(); - cal.setTime(event.getEventDate()); - values[5] = cal.get(Calendar.MONTH) + 1; - values[6] = cal.get(Calendar.DAY_OF_WEEK); - values[7] = (values[6] == Calendar.SUNDAY || values[6] == Calendar.SATURDAY) ? 1 : 0; - values[8] = event.getTitle() != null ? event.getTitle().length() : 0; - values[9] = event.getDescription() != null ? event.getDescription().split("\\s+").length : 0; + private Instances oversampleClassWithSMOTE(Instances classData, int targetSize) throws Exception { + int currentSize = classData.numInstances(); + int percentage = (targetSize * 100) / currentSize - 100; - // Check for keyword presence (indices 10-25) - String title = event.getTitle() != null ? event.getTitle().toLowerCase() : ""; - String desc = event.getDescription() != null ? event.getDescription().toLowerCase() : ""; + SMOTE smote = new SMOTE(); + smote.setInputFormat(classData); + smote.setNearestNeighbors(Math.min(SMOTE_NEIGHBORS, currentSize - 1)); + smote.setPercentage(Math.max(100, percentage)); -// Map> keywords = createEventTypeKeywords(); - int featureIndex = 10; + return Filter.useFilter(classData, smote); + } + private Instances applyBasicSMOTE(Instances data) { + try { + SMOTE smote = new SMOTE(); + smote.setInputFormat(data); + smote.setNearestNeighbors(SMOTE_NEIGHBORS); + smote.setPercentage(SMOTE_PERCENTAGE); + + Instances balancedData = Filter.useFilter(data, smote); + logBalancedDistribution(balancedData); + return balancedData; + } catch (Exception e) { + ClassifierLogger.log("SMOTE failed, using original data: " + e.getMessage()); + return data; + } + } - // Class value - Attribute classAttr = data.classAttribute(); - int classValueIndex = classAttr.indexOfValue(event.getEventType()); - values[data.classIndex()] = classValueIndex; + private void logBalancedDistribution(Instances data) { + Map balancedCounts = new HashMap<>(); + for (int i = 0; i < data.numInstances(); i++) { + double classValue = data.instance(i).classValue(); + balancedCounts.merge(classValue, 1, Integer::sum); + } - data.add(new DenseInstance(1.0, values)); + ClassifierLogger.log("Class distribution after balancing:"); + balancedCounts.forEach((classVal, count) -> { + String className = data.classAttribute().value(classVal.intValue()); + ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", + className, count, (double) count * 100 / data.numInstances())); + }); } - private double[] addValuesToDataset(Event event, int month, Instances tempData) { - double[] values = new double[tempData.numAttributes()]; + private void trainClassifier(Instances data) throws Exception { + classifier = new NaiveBayes(); + classifier.setUseKernelEstimator(false); + classifier.setUseSupervisedDiscretization(true); + classifier.buildClassifier(data); + trainingHeader = new Instances(data, 0); + } + + private void evaluateModel(Instances data) throws Exception { + Evaluation eval = new Evaluation(data); + eval.crossValidateModel(classifier, data, 10, new Random(RANDOM_SEED)); + + ClassifierLogger.log("=== Classification Results ==="); + ClassifierLogger.log(eval.toSummaryString()); + + ClassifierLogger.log("\n=== Detailed Accuracy By Class ==="); + ClassifierLogger.log(eval.toClassDetailsString()); + + ClassifierLogger.log("\n=== Confusion Matrix ==="); + ClassifierLogger.log(eval.toMatrixString()); + } + + private ArrayList createAttributes() { + ArrayList attributes = new ArrayList<>(); + + // Text attributes + attributes.add(new Attribute("title", (List) null)); + attributes.add(new Attribute("description", (List) null)); + attributes.add(new Attribute("hostName", (List) null)); + attributes.add(new Attribute("location", (List) null)); + + // Numeric attributes + attributes.add(new Attribute("eventCapacity")); + attributes.add(new Attribute("eventMonth")); + attributes.add(new Attribute("eventDayOfWeek")); + attributes.add(new Attribute("isWeekend")); + + // Enhanced features + attributes.add(new Attribute("titleLength")); + attributes.add(new Attribute("descriptionWordCount")); + attributes.add(new Attribute("hasWorkshopInTitle")); // Binary + attributes.add(new Attribute("hasWorkshopInDesc")); // Binary + attributes.add(new Attribute("hasHandsOnInDesc")); // Binary + + return attributes; + } + + private void addInstance(Instances data, Event event) { + double[] values = new double[data.numAttributes()]; // Basic features - values[0] = tempData.attribute(0).addStringValue(Objects.toString(event.getTitle(), "")); - values[1] = tempData.attribute(1).addStringValue(Objects.toString(event.getDescription(), "")); - values[2] = tempData.attribute(2).addStringValue(Objects.toString(event.getEventHost(), "")); - values[3] = tempData.attribute(3).addStringValue(Objects.toString(event.getEventLocation(), "")); + values[0] = data.attribute(0).addStringValue(preprocessText(event.getTitle())); + values[1] = data.attribute(1).addStringValue(preprocessText(event.getDescription())); + values[2] = data.attribute(2).addStringValue(preprocessText(event.getEventHost())); + values[3] = data.attribute(3).addStringValue(preprocessText(event.getEventLocation())); values[4] = event.getEventCapacity(); // Temporal features Calendar cal = Calendar.getInstance(); cal.setTime(event.getEventDate()); - values[5] = month; + values[5] = cal.get(Calendar.MONTH) + 1; values[6] = cal.get(Calendar.DAY_OF_WEEK); - values[7] = (values[6] == 1 || values[6] == 7) ? 1 : 0; // isWeekend + values[7] = (values[6] == Calendar.SUNDAY || values[6] == Calendar.SATURDAY) ? 1 : 0; - // Textual features - String title = Objects.toString(event.getTitle(), ""); - String description = Objects.toString(event.getDescription(), ""); + // Text features + String title = event.getTitle() != null ? event.getTitle().toLowerCase() : ""; + String desc = event.getDescription() != null ? event.getDescription().toLowerCase() : ""; values[8] = title.length(); - values[9] = description.split("\\s+").length; - - return values; - } + values[9] = desc.split("\\s+").length; - private Instances createFilteredInstance(Event event) throws Exception { - ArrayList originalAttributes = createAttributes(); - originalAttributes.add(trainingHeader.classAttribute()); - - Calendar cal = Calendar.getInstance(); - cal.setTime(event.getEventDate()); - int month = cal.get(Calendar.MONTH) + 1; - - Instances tempData = new Instances("TempData", originalAttributes, 0); - tempData.setClassIndex(tempData.numAttributes() - 1); + // Workshop-specific features + values[10] = title.contains("workshop") ? 1 : 0; + values[11] = desc.contains("workshop") ? 1 : 0; + values[12] = desc.contains("hands-on") || desc.contains("practical") ? 1 : 0; - double[] values = addValuesToDataset(event, month, tempData); - values[6] = 0; + // Class value + values[data.classIndex()] = data.classAttribute().indexOfValue(event.getEventType()); - tempData.add(new DenseInstance(1.0, values)); - return applyStringToWordVectorFilter(tempData); + data.add(new DenseInstance(1.0, values)); } private String preprocessText(String text) { if (text == null) return ""; - - return text.toLowerCase() - .replaceAll("\\s+", " ") - .trim(); + return text.toLowerCase().replaceAll("\\s+", " ").trim(); } - public String predictEventType(Event event) { try { - Instance instance = new DenseInstance(trainingHeader.numAttributes()); - instance.setDataset(trainingHeader); - Instances filteredData = createFilteredInstance(event); - double prediction = classifier.classifyInstance(filteredData.firstInstance()); return trainingHeader.classAttribute().value((int) prediction); - } catch (Exception e) { throw new RuntimeException("Failed to make prediction: " + e.getMessage()); } @@ -508,78 +343,75 @@ public String predictEventType(Event event) { public Map getPredictionProbabilities(Event event) { try { - ArrayList originalAttributes = createAttributes(); - originalAttributes.add(trainingHeader.classAttribute()); - Instances filteredData = createFilteredInstance(event); - double[] distributions = classifier.distributionForInstance(filteredData.firstInstance()); - Map probabilities = new HashMap<>(); + Map probabilities = new HashMap<>(); for (int i = 0; i < distributions.length; i++) { String type = trainingHeader.classAttribute().value(i); probabilities.put(type, distributions[i]); } - return probabilities; - } catch (Exception e) { throw new RuntimeException("Failed to get prediction probabilities: " + e.getMessage()); } } - // Add these to your EventTypeClassifier class - public void saveModel(String filePath) throws Exception { - SerializationHelper.write(filePath, classifier); - } + private Instances createFilteredInstance(Event event) throws Exception { + ArrayList originalAttributes = createAttributes(); + originalAttributes.add(trainingHeader.classAttribute()); - public static EventTypeClassifier loadModel(String filePath) throws Exception { - EventTypeClassifier loaded = new EventTypeClassifier(); - loaded.classifier = (NaiveBayes) SerializationHelper.read(filePath); - return loaded; - } + Calendar cal = Calendar.getInstance(); + cal.setTime(event.getEventDate()); + int month = cal.get(Calendar.MONTH) + 1; - public Instances getTrainingData() { - return new Instances(trainingHeader); - } + Instances tempData = new Instances("TempData", originalAttributes, 0); + tempData.setClassIndex(tempData.numAttributes() - 1); - // Add methods needed by ImprovedEventClassifierTester + double[] values = addValuesToDataset(event, month, tempData); + values[6] = 0; // Reset day of week for prediction - public Classifier getClassifier() { - return classifier; + tempData.add(new DenseInstance(1.0, values)); + return applyTextProcessingFilters(tempData); } + private double[] addValuesToDataset(Event event, int month, Instances tempData) { + double[] values = new double[tempData.numAttributes()]; - /** - * Converts a list of events to Weka Instances for cross-validation - */ - public Instances convertEventsToInstances(List events) throws Exception { - // Filter valid events - List validEvents = events.stream() - .filter(e -> e.getEventType() != null && !e.getEventType().isEmpty()) - .toList(); + // Basic features + values[0] = tempData.attribute(0).addStringValue(Objects.toString(event.getTitle(), "")); + values[1] = tempData.attribute(1).addStringValue(Objects.toString(event.getDescription(), "")); + values[2] = tempData.attribute(2).addStringValue(Objects.toString(event.getEventHost(), "")); + values[3] = tempData.attribute(3).addStringValue(Objects.toString(event.getEventLocation(), "")); + values[4] = event.getEventCapacity(); - // Get unique event types + // Temporal features + Calendar cal = Calendar.getInstance(); + cal.setTime(event.getEventDate()); + values[5] = month; + values[6] = cal.get(Calendar.DAY_OF_WEEK); + values[7] = (values[6] == 1 || values[6] == 7) ? 1 : 0; // isWeekend - ArrayList eventTypes = validEvents.stream() - .map(Event::getEventType) - .filter(Objects::nonNull).distinct().collect(Collectors.toCollection(ArrayList::new)); - ArrayList attributes = createAttributes(); + // Textual features + String title = Objects.toString(event.getTitle(), ""); + String description = Objects.toString(event.getDescription(), ""); + values[8] = title.length(); + values[9] = description.split("\\s+").length; - // Add class attribute - Attribute classAttribute = new Attribute("class_event_type", eventTypes); - attributes.add(classAttribute); + return values; + } - // Create dataset - Instances data = new Instances("EventData", attributes, 0); - data.setClassIndex(data.numAttributes() - 1); + public void saveModel(String filePath) throws Exception { + SerializationHelper.write(filePath, classifier); + } - // Add instances - for (Event event : validEvents) { - addInstance(data, event); - } + public static EventTypeClassifier loadModel(String filePath) throws Exception { + EventTypeClassifier loaded = new EventTypeClassifier(); + loaded.classifier = (NaiveBayes) SerializationHelper.read(filePath); + return loaded; + } - // Apply filtering - return applyStringToWordVectorFilter(data); + public Instances getTrainingData() { + return new Instances(trainingHeader); } } \ No newline at end of file diff --git a/emmas/src/main/java/dev/brianweloba/weka/ImprovedEventClassifierTester.java b/emmas/src/main/java/dev/brianweloba/weka/ImprovedEventClassifierTester.java deleted file mode 100644 index 3df02a5..0000000 --- a/emmas/src/main/java/dev/brianweloba/weka/ImprovedEventClassifierTester.java +++ /dev/null @@ -1,319 +0,0 @@ -package dev.brianweloba.weka; - -import dev.brianweloba.model.Event; -import weka.classifiers.Evaluation; -import weka.classifiers.trees.RandomForest; -import weka.core.Instances; -import weka.core.converters.ArffSaver; - -import java.io.File; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.*; - -public class ImprovedEventClassifierTester { - private final EventTypeClassifier classifier; - private final String outputDir; - - public ImprovedEventClassifierTester(String outputDir) { - this.classifier = new EventTypeClassifier(); - this.outputDir = outputDir; - new File(outputDir).mkdirs(); // Ensure directory exists - } - - public void runFullTest(List allEvents) throws Exception { - ClassifierLogger.log("Starting improved testing process with balanced evaluation"); - - // Ensure we have enough data - if (allEvents.size() < 50) { - ClassifierLogger.log("Warning: Small dataset size (" + allEvents.size() + " events)"); - } - - // Check class distribution before processing - Map typeCounts = new HashMap<>(); - for (Event event : allEvents) { - if (event.getEventType() != null) { - typeCounts.merge(event.getEventType(), 1, Integer::sum); - } - } - - // Log distribution - ClassifierLogger.log("Class distribution in full dataset:"); - typeCounts.forEach((type, count) -> - ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", - type, count, (double)count * 100 / allEvents.size()))); - - // Split data (70% train, 30% test for better test set size) - Collections.shuffle(allEvents, new Random(42)); // Shuffle with fixed seed for reproducibility - int splitPoint = (int) (allEvents.size() * 0.7); - List trainingEvents = allEvents.subList(0, splitPoint); - List testEvents = allEvents.subList(splitPoint, allEvents.size()); - - // Train and save model - ClassifierLogger.log("Starting training with " + trainingEvents.size() + " events"); - classifier.trainModel(trainingEvents); - saveModelAndData(trainingEvents, "training"); - - // Test and save results - ClassifierLogger.log("Starting testing with " + testEvents.size() + " events"); - TestResults results = evaluateOnTestSet(testEvents); - - // Calculate per-class metrics - Map classMetrics = calculatePerClassMetrics(testEvents); - - // Save detailed results - saveTestResults(results, testEvents, classMetrics); - - // Log metrics - ClassifierLogger.logMetrics( - results.accuracy, results.precision, - results.recall, results.f1, - trainingEvents.size(), testEvents.size() - ); - - // Perform cross-validation for more robust evaluation - performCrossValidation(allEvents,5); - } - - // In the performCrossValidation method, create a new classifier for each fold - private void performCrossValidation(List allEvents, int folds) { - try { - // Generate a stratified sample of the data - Random rand = new Random(42); - Collections.shuffle(allEvents, rand); - - // Create CV partitions manually - int foldSize = allEvents.size() / folds; - double totalAccuracy = 0; - - for (int i = 0; i < folds; i++) { - // Split into training and testing for this fold - List cvTest = new ArrayList<>(allEvents.subList(i * foldSize, (i + 1) * foldSize)); - List cvTrain = new ArrayList<>(allEvents); - cvTrain.removeAll(cvTest); - - // Create a fresh classifier for this fold - EventTypeClassifier cvClassifier = new EventTypeClassifier(); - - // Train it on this fold's training data - cvClassifier.trainModel(cvTrain); - - // Test on this fold's test data - int correct = 0; - for (Event event : cvTest) { - String predicted = cvClassifier.predictEventType(event); - if (predicted.equals(event.getEventType())) { - correct++; - } - } - - double foldAccuracy = (double) correct / cvTest.size(); - totalAccuracy += foldAccuracy; - - ClassifierLogger.log(String.format("Fold %d accuracy: %.2f%%", i+1, foldAccuracy * 100)); - } - - double avgAccuracy = totalAccuracy / folds; - ClassifierLogger.log(String.format("Cross-validation average accuracy: %.2f%%", avgAccuracy * 100)); - - } catch (Exception e) { - ClassifierLogger.log("Error during cross-validation: " + e.getMessage()); - e.printStackTrace(); - } - } - - private void saveConfusionMatrix(Evaluation eval, Instances data) { - try (PrintWriter writer = new PrintWriter(new File(outputDir + "/confusion_matrix.txt"))) { - writer.println("=== Confusion Matrix ==="); - writer.println(eval.toMatrixString()); - - writer.println("\n=== Detailed Accuracy By Class ==="); - writer.println(eval.toClassDetailsString()); - } catch (Exception e) { - ClassifierLogger.log("Error saving confusion matrix: " + e.getMessage()); - } - } - - private Map calculatePerClassMetrics(List testEvents) { - Map metrics = new HashMap<>(); - Map truePositives = new HashMap<>(); - Map falsePositives = new HashMap<>(); - Map falseNegatives = new HashMap<>(); - Map totalActual = new HashMap<>(); - - // Count actual class distribution - for (Event event : testEvents) { - String actual = event.getEventType(); - totalActual.merge(actual, 1, Integer::sum); - } - - // Calculate metrics - for (Event event : testEvents) { - String actual = event.getEventType(); - String predicted = classifier.predictEventType(event); - - if (actual.equals(predicted)) { - truePositives.merge(actual, 1, Integer::sum); - } else { - falseNegatives.merge(actual, 1, Integer::sum); - falsePositives.merge(predicted, 1, Integer::sum); - } - } - - // Calculate precision, recall, and F1 for each class - for (String className : totalActual.keySet()) { - int tp = truePositives.getOrDefault(className, 0); - int fp = falsePositives.getOrDefault(className, 0); - int fn = falseNegatives.getOrDefault(className, 0); - int total = totalActual.get(className); - - double precision = (tp + fp == 0) ? 0 : (double) tp / (tp + fp); - double recall = (double) tp / total; - double f1 = (precision + recall == 0) ? 0 : 2 * precision * recall / (precision + recall); - - metrics.put(className, new ClassMetrics(precision, recall, f1, total)); - } - - return metrics; - } - - private TestResults evaluateOnTestSet(List testEvents) { - TestResults results = new TestResults(); - int correct = 0; - Map classCounts = new HashMap<>(); - Map correctCounts = new HashMap<>(); - - try (PrintWriter writer = new PrintWriter(new File(outputDir + "/detailed_results.csv"))) { - writer.println("actual,predicted,confidence,title,is_correct"); - - for (Event event : testEvents) { - String actual = event.getEventType(); - String predicted = classifier.predictEventType(event); - Map probs = classifier.getPredictionProbabilities(event); - double confidence = probs.getOrDefault(predicted, 0.0); - boolean isCorrect = actual.equals(predicted); - - writer.printf("%s,%s,%.4f,%s,%b%n", - actual, predicted, confidence, - event.getTitle().replace(",", ""), isCorrect); - - // Update class statistics - classCounts.merge(actual, 1, Integer::sum); - if (isCorrect) { - correct++; - correctCounts.merge(actual, 1, Integer::sum); - } else { - ClassifierLogger.log(String.format( - "MISCLASSIFIED: Actual=%s, Predicted=%s (%.2f%%) for '%s'", - actual, predicted, confidence*100, event.getTitle() - )); - } - } - - // Calculate global metrics - results.accuracy = (double) correct / testEvents.size(); - - // Calculate confusion matrix-based metrics - double totalPrecision = 0; - double totalRecall = 0; - int classCount = classCounts.size(); - - for (String className : classCounts.keySet()) { - int truePositives = correctCounts.getOrDefault(className, 0); - int total = classCounts.get(className); - - // Calculate predicted count (for precision) - long predictedCount = testEvents.stream() - .filter(e -> classifier.predictEventType(e).equals(className)) - .count(); - - double precision = predictedCount == 0 ? 0 : (double) truePositives / predictedCount; - double recall = (double) truePositives / total; - - totalPrecision += precision; - totalRecall += recall; - } - - // Calculate macro-averaged metrics - results.precision = totalPrecision / classCount; - results.recall = totalRecall / classCount; - results.f1 = (results.precision + results.recall == 0) ? 0 : - 2 * results.precision * results.recall / (results.precision + results.recall); - - } catch (Exception e) { - ClassifierLogger.log("Error during testing: " + e.getMessage()); - } - - return results; - } - - private void saveModelAndData(List events, String prefix) throws Exception { - // Save ARFF file for inspection - Instances data = classifier.getTrainingData(); // You'll need to add this getter - ArffSaver saver = new ArffSaver(); - saver.setInstances(data); - saver.setFile(new File(outputDir + "/" + prefix + "_data.arff")); - saver.writeBatch(); - - // Save model (you'll need to add model serialization to your classifier) - classifier.saveModel(outputDir + "/" + prefix + "_model.model"); - } - - private void saveTestResults(TestResults results, List testEvents, - Map classMetrics) { - try (PrintWriter writer = new PrintWriter(new File(outputDir + "/summary_results.txt"))) { - writer.println("=== Classifier Test Results ==="); - writer.printf("Accuracy: %.2f%%%n", results.accuracy * 100); - writer.printf("Macro-Averaged Precision: %.4f%n", results.precision); - writer.printf("Macro-Averaged Recall: %.4f%n", results.recall); - writer.printf("Macro-Averaged F1 Score: %.4f%n", results.f1); - - writer.println("\n=== Per-Class Performance ==="); - for (Map.Entry entry : classMetrics.entrySet()) { - ClassMetrics metrics = entry.getValue(); - writer.printf("%s (n=%d):%n", entry.getKey(), metrics.count); - writer.printf(" Precision: %.4f%n", metrics.precision); - writer.printf(" Recall: %.4f%n", metrics.recall); - writer.printf(" F1 Score: %.4f%n", metrics.f1); - } - - writer.println("\n=== Test Set Distribution ==="); - - // Count actual class distribution - Map actualCounts = new HashMap<>(); - for (Event e : testEvents) { - actualCounts.merge(e.getEventType(), 1, Integer::sum); - } - - for (Map.Entry entry : actualCounts.entrySet()) { - writer.printf("%s: %d (%.1f%%)%n", - entry.getKey(), entry.getValue(), - (double) entry.getValue() / testEvents.size() * 100); - } - - } catch (IOException e) { - ClassifierLogger.log("Error saving test results: " + e.getMessage()); - } - } - - private static class TestResults { - double accuracy; - double precision; - double recall; - double f1; - } - - private static class ClassMetrics { - double precision; - double recall; - double f1; - int count; - - public ClassMetrics(double precision, double recall, double f1, int count) { - this.precision = precision; - this.recall = recall; - this.f1 = f1; - this.count = count; - } - } -} \ No newline at end of file diff --git a/emmas/src/main/java/dev/brianweloba/weka/RunTest.java b/emmas/src/main/java/dev/brianweloba/weka/RunTest.java index b2c99a2..e1ff1e3 100644 --- a/emmas/src/main/java/dev/brianweloba/weka/RunTest.java +++ b/emmas/src/main/java/dev/brianweloba/weka/RunTest.java @@ -3,18 +3,15 @@ import dev.brianweloba.dao.EventDAO; import dev.brianweloba.model.Event; -import java.util.ArrayList; import java.util.List; public class RunTest { public static void main(String[] args) { try { List allEvents = loadAllEvents(); - - ImprovedEventClassifierTester tester = new ImprovedEventClassifierTester("improved_classifier_results"); + EventClassifierTester tester = new EventClassifierTester("improved_classifier_results"); tester.runFullTest(allEvents); - ClassifierLogger.log("Testing completed successfully"); } catch (Exception e) { ClassifierLogger.log("Fatal error in testing: " + e.getMessage()); From 81609bd01828552c414440d46ebd059671aab6e9 Mon Sep 17 00:00:00 2001 From: Brian-Weloba Date: Mon, 7 Apr 2025 09:23:33 +0100 Subject: [PATCH 5/7] Enhance event classification; refactor EventClassifierTester for improved logging, class distribution summary, and cross-validation metrics. Update EventTypeClassifier to utilize AdaBoost for better model performance and parameter optimization. --- .../weka/EventClassifierTester.java | 131 +++++++++++------- .../brianweloba/weka/EventTypeClassifier.java | 123 ++++++++++++++-- 2 files changed, 191 insertions(+), 63 deletions(-) diff --git a/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java b/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java index 11ec270..00162af 100644 --- a/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java +++ b/emmas/src/main/java/dev/brianweloba/weka/EventClassifierTester.java @@ -13,6 +13,7 @@ public class EventClassifierTester { private final EventTypeClassifier classifier; private final String outputDir; + private static final boolean VERBOSE_LOGGING = true; // Toggle for detailed logs public EventClassifierTester(String outputDir) { this.classifier = new EventTypeClassifier(); @@ -21,40 +22,29 @@ public EventClassifierTester(String outputDir) { } public void runFullTest(List allEvents) throws Exception { - ClassifierLogger.log("Starting improved testing process with balanced evaluation"); + ClassifierLogger.log("Starting event classification evaluation"); // Ensure we have enough data if (allEvents.size() < 50) { ClassifierLogger.log("Warning: Small dataset size (" + allEvents.size() + " events)"); } - // Check class distribution before processing - Map typeCounts = new HashMap<>(); - for (Event event : allEvents) { - if (event.getEventType() != null) { - typeCounts.merge(event.getEventType(), 1, Integer::sum); - } - } - - // Log distribution - ClassifierLogger.log("Class distribution in full dataset:"); - typeCounts.forEach((type, count) -> - ClassifierLogger.log(String.format(" %s: %d (%.1f%%)", - type, count, (double)count * 100 / allEvents.size()))); + // Quick summary of class distribution + logClassDistributionSummary(allEvents); - // Split data (70% train, 30% test for better test set size) - Collections.shuffle(allEvents, new Random(42)); // Shuffle with fixed seed for reproducibility + // Split data (70% train, 30% test) + Collections.shuffle(allEvents, new Random(42)); // Fixed seed for reproducibility int splitPoint = (int) (allEvents.size() * 0.7); List trainingEvents = allEvents.subList(0, splitPoint); List testEvents = allEvents.subList(splitPoint, allEvents.size()); - // Train and save model - ClassifierLogger.log("Starting training with " + trainingEvents.size() + " events"); + // Train model + ClassifierLogger.log("Training with " + trainingEvents.size() + " events"); classifier.trainModel(trainingEvents); saveModelAndData(trainingEvents, "training"); // Test and save results - ClassifierLogger.log("Starting testing with " + testEvents.size() + " events"); + ClassifierLogger.log("Testing with " + testEvents.size() + " events"); TestResults results = evaluateOnTestSet(testEvents); // Calculate per-class metrics @@ -63,41 +53,58 @@ public void runFullTest(List allEvents) throws Exception { // Save detailed results saveTestResults(results, testEvents, classMetrics); - // Log metrics - ClassifierLogger.logMetrics( - results.accuracy, results.precision, - results.recall, results.f1, - trainingEvents.size(), testEvents.size() - ); + // Log summary metrics + logSummaryMetrics(results); - // Perform cross-validation for more robust evaluation - performCrossValidation(allEvents,10); + // Perform cross-validation + performCrossValidation(allEvents, 5); + } + + private void logClassDistributionSummary(List events) { + Map typeCounts = new HashMap<>(); + for (Event event : events) { + if (event.getEventType() != null) { + typeCounts.merge(event.getEventType(), 1, Integer::sum); + } + } + + ClassifierLogger.log("Class distribution: " + + typeCounts.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .map(e -> String.format("%s=%d(%.1f%%)", + e.getKey(), e.getValue(), (double)e.getValue()*100/events.size())) + .reduce((a, b) -> a + ", " + b) + .orElse("")); + } + + private void logSummaryMetrics(TestResults results) { + ClassifierLogger.log(String.format( + "Results: Accuracy=%.1f%%, Precision=%.3f, Recall=%.3f, F1=%.3f", + results.accuracy * 100, results.precision, results.recall, results.f1 + )); } - // In the performCrossValidation method, create a new classifier for each fold private void performCrossValidation(List allEvents, int folds) { try { - // Generate a stratified sample of the data Random rand = new Random(42); Collections.shuffle(allEvents, rand); - // Create CV partitions manually int foldSize = allEvents.size() / folds; double totalAccuracy = 0; + double[] foldAccuracies = new double[folds]; for (int i = 0; i < folds; i++) { - // Split into training and testing for this fold - List cvTest = new ArrayList<>(allEvents.subList(i * foldSize, (i + 1) * foldSize)); + // Split for this fold + List cvTest = new ArrayList<>(allEvents.subList(i * foldSize, + Math.min((i + 1) * foldSize, allEvents.size()))); List cvTrain = new ArrayList<>(allEvents); cvTrain.removeAll(cvTest); - // Create a fresh classifier for this fold + // Train EventTypeClassifier cvClassifier = new EventTypeClassifier(); - - // Train it on this fold's training data cvClassifier.trainModel(cvTrain); - // Test on this fold's test data + // Test int correct = 0; for (Event event : cvTest) { String predicted = cvClassifier.predictEventType(event); @@ -106,18 +113,25 @@ private void performCrossValidation(List allEvents, int folds) { } } - double foldAccuracy = (double) correct / cvTest.size(); - totalAccuracy += foldAccuracy; + foldAccuracies[i] = (double) correct / cvTest.size(); + totalAccuracy += foldAccuracies[i]; - ClassifierLogger.log(String.format("Fold %d accuracy: %.2f%%", i+1, foldAccuracy * 100)); + if (VERBOSE_LOGGING) { + ClassifierLogger.log(String.format("Fold %d accuracy: %.1f%%", + i+1, foldAccuracies[i] * 100)); + } } + // Report the summary with min/max/avg double avgAccuracy = totalAccuracy / folds; - ClassifierLogger.log(String.format("Cross-validation average accuracy: %.2f%%", avgAccuracy * 100)); + double minAccuracy = Arrays.stream(foldAccuracies).min().orElse(0); + double maxAccuracy = Arrays.stream(foldAccuracies).max().orElse(0); + + ClassifierLogger.log(String.format("CV results: Avg=%.1f%%, Min=%.1f%%, Max=%.1f%% (%d folds)", + avgAccuracy * 100, minAccuracy * 100, maxAccuracy * 100, folds)); } catch (Exception e) { - ClassifierLogger.log("Error during cross-validation: " + e.getMessage()); - e.printStackTrace(); + ClassifierLogger.log("CV error: " + e.getMessage()); } } @@ -125,7 +139,6 @@ private void saveConfusionMatrix(Evaluation eval, Instances data) { try (PrintWriter writer = new PrintWriter(new File(outputDir + "/confusion_matrix.txt"))) { writer.println("=== Confusion Matrix ==="); writer.println(eval.toMatrixString()); - writer.println("\n=== Detailed Accuracy By Class ==="); writer.println(eval.toClassDetailsString()); } catch (Exception e) { @@ -181,6 +194,7 @@ private TestResults evaluateOnTestSet(List testEvents) { int correct = 0; Map classCounts = new HashMap<>(); Map correctCounts = new HashMap<>(); + List misclassifications = new ArrayList<>(); try (PrintWriter writer = new PrintWriter(new File(outputDir + "/detailed_results.csv"))) { writer.println("actual,predicted,confidence,title,is_correct"); @@ -202,10 +216,10 @@ private TestResults evaluateOnTestSet(List testEvents) { correct++; correctCounts.merge(actual, 1, Integer::sum); } else { - ClassifierLogger.log(String.format( - "MISCLASSIFIED: Actual=%s, Predicted=%s (%.2f%%) for '%s'", - actual, predicted, confidence*100, event.getTitle() - )); + // Store misclassifications instead of logging each one + misclassifications.add(String.format( + "Actual=%s, Predicted=%s (%.0f%%) for '%s'", + actual, predicted, confidence*100, event.getTitle())); } } @@ -239,6 +253,25 @@ private TestResults evaluateOnTestSet(List testEvents) { results.f1 = (results.precision + results.recall == 0) ? 0 : 2 * results.precision * results.recall / (results.precision + results.recall); + // Log misclassification summary + int misclassificationCount = misclassifications.size(); + if (misclassificationCount > 0) { + ClassifierLogger.log(String.format("Misclassifications: %d/%d (%.1f%%)", + misclassificationCount, testEvents.size(), + (double)misclassificationCount*100/testEvents.size())); + + // Only log the first few misclassifications if verbose + if (VERBOSE_LOGGING) { + int limit = Math.min(5, misclassificationCount); + for (int i = 0; i < limit; i++) { + ClassifierLogger.log("MISCLASSIFIED: " + misclassifications.get(i)); + } + if (misclassificationCount > limit) { + ClassifierLogger.log("... and " + (misclassificationCount - limit) + " more"); + } + } + } + } catch (Exception e) { ClassifierLogger.log("Error during testing: " + e.getMessage()); } @@ -248,13 +281,13 @@ private TestResults evaluateOnTestSet(List testEvents) { private void saveModelAndData(List events, String prefix) throws Exception { // Save ARFF file for inspection - Instances data = classifier.getTrainingData(); // You'll need to add this getter + Instances data = classifier.getTrainingData(); ArffSaver saver = new ArffSaver(); saver.setInstances(data); saver.setFile(new File(outputDir + "/" + prefix + "_data.arff")); saver.writeBatch(); - // Save model (you'll need to add model serialization to your classifier) + // Save model classifier.saveModel(outputDir + "/" + prefix + "_model.model"); } diff --git a/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java b/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java index 25572d1..26435cd 100644 --- a/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java +++ b/emmas/src/main/java/dev/brianweloba/weka/EventTypeClassifier.java @@ -3,12 +3,14 @@ import dev.brianweloba.model.Event; import weka.classifiers.bayes.NaiveBayes; import weka.classifiers.evaluation.Evaluation; +import weka.classifiers.meta.AdaBoostM1; import weka.core.*; import weka.core.stopwords.Rainbow; import weka.core.tokenizers.NGramTokenizer; import weka.filters.Filter; import weka.filters.supervised.instance.SMOTE; import weka.filters.unsupervised.attribute.StringToWordVector; +import weka.classifiers.meta.CVParameterSelection; import java.util.*; import java.util.stream.Collectors; @@ -21,9 +23,10 @@ public class EventTypeClassifier { private static final int NGRAM_MIN_SIZE = 1; private static final int NGRAM_MAX_SIZE = 2; private static final int WORDS_TO_KEEP = 500; - private static final long RANDOM_SEED = 42L; + private static final int RANDOM_SEED = 42; - private NaiveBayes classifier; + // private NaiveBayes classifier; + private AdaBoostM1 classifier; private Instances trainingHeader; public void trainModel(List events) { @@ -38,9 +41,10 @@ public void trainModel(List events) { trainClassifier(balancedData); evaluateModel(balancedData); - ClassifierLogger.log("NaiveBayes model trained successfully with " + - balancedData.numInstances() + " balanced instances and " + - balancedData.numAttributes() + " features"); + // Add this after model training: + ClassifierLogger.log("\n=== AdaBoost Configuration ==="); + ClassifierLogger.log("Iterations: " + classifier.getNumIterations()); + ClassifierLogger.log("Using classifier: " + classifier.getClassifier().getClass().getSimpleName()); } catch (Exception e) { throw new RuntimeException("Failed to train model: " + e.getMessage(), e); } @@ -245,11 +249,108 @@ private void logBalancedDistribution(Instances data) { }); } + private void configureAdaBoost(AdaBoostM1 adaBoost) { + adaBoost.setDebug(false); + adaBoost.setNumIterations(15); // Can experiment with 10-20 + adaBoost.setWeightThreshold(50); + adaBoost.setUseResampling(false); + adaBoost.setResume(false); + } + +// private void trainClassifier(Instances data) throws Exception { +// // 1. Create base classifier +// NaiveBayes nb = new NaiveBayes(); +// nb.setUseKernelEstimator(false); +// nb.setUseSupervisedDiscretization(true); +// +// // 2. Create AdaBoost with default parameters +// AdaBoostM1 baseAdaBoost = new AdaBoostM1(); +// baseAdaBoost.setClassifier(nb); +// baseAdaBoost.setSeed(RANDOM_SEED); +// +// // 3. Setup parameter search +// CVParameterSelection ps = new CVParameterSelection(); +// ps.setClassifier(baseAdaBoost); +// ps.setNumFolds(5); // 5-fold cross-validation +// +// // Fix: Correct parameter flag for iterations (should be "I" not "P") +// ps.addCVParameter("I 10 20 10"); // Tests 10, 15, 20 iterations +// +// // Optional: Search weightThreshold too +// // ps.addCVParameter("W 50 150 50"); // Tests 50, 100, 150 +// +// // 4. Run optimization +// ps.buildClassifier(data); +// +// // 5. Apply best parameters manually +// classifier = new AdaBoostM1(); +// classifier.setClassifier(nb); +// +// // Get best options and apply them +// String[] bestOptions = ps.getBestClassifierOptions(); +// classifier.setOptions(bestOptions); +// +// // Fix: Train the classifier with the optimized parameters +// classifier.buildClassifier(data); +// +// // 6. Log the selected parameters +// ClassifierLogger.log("\n=== Optimized AdaBoost Parameters ==="); +// ClassifierLogger.log("Selected iterations: " + classifier.getNumIterations()); +// ClassifierLogger.log("Best configuration: " + Arrays.toString(bestOptions)); +// +// trainingHeader = new Instances(data, 0); +// } + private void trainClassifier(Instances data) throws Exception { - classifier = new NaiveBayes(); - classifier.setUseKernelEstimator(false); - classifier.setUseSupervisedDiscretization(true); + // 1. Create base classifier + NaiveBayes nb = new NaiveBayes(); + nb.setUseKernelEstimator(false); + nb.setUseSupervisedDiscretization(true); + + // 2. Perform manual parameter search for AdaBoost + int[] iterationsToTest = {5, 10, 15, 20, 25}; // The iterations we want to test + double bestAccuracy = 0.0; + int bestIterations = 10; // Default value + + ClassifierLogger.log("\n=== Manual Parameter Search ==="); + + for (int iterations : iterationsToTest) { + // Create and configure a fresh AdaBoost instance for each test + AdaBoostM1 testBoost = new AdaBoostM1(); + testBoost.setClassifier(nb); + testBoost.setSeed(RANDOM_SEED); + testBoost.setNumIterations(iterations); + + // Evaluate using cross-validation + Evaluation eval = new Evaluation(data); + eval.crossValidateModel(testBoost, data, 5, new Random(RANDOM_SEED)); + + double accuracy = eval.pctCorrect(); + ClassifierLogger.log(String.format(" Iterations=%d, Accuracy=%.2f%%", + iterations, accuracy)); + + // Keep track of the best configuration + if (accuracy > bestAccuracy) { + bestAccuracy = accuracy; + bestIterations = iterations; + } + } + + // 3. Create the final classifier with the best parameters + classifier = new AdaBoostM1(); + classifier.setClassifier(nb); + classifier.setSeed(RANDOM_SEED); + classifier.setNumIterations(bestIterations); + + // Train the final model classifier.buildClassifier(data); + + // 4. Log the selected parameters + ClassifierLogger.log("\n=== Optimized AdaBoost Parameters ==="); + ClassifierLogger.log("Selected iterations: " + bestIterations); + ClassifierLogger.log("Best accuracy: " + String.format("%.2f%%", bestAccuracy)); + ClassifierLogger.log("Using classifier: " + classifier.getClassifier().getClass().getSimpleName()); + trainingHeader = new Instances(data, 0); } @@ -405,12 +506,6 @@ public void saveModel(String filePath) throws Exception { SerializationHelper.write(filePath, classifier); } - public static EventTypeClassifier loadModel(String filePath) throws Exception { - EventTypeClassifier loaded = new EventTypeClassifier(); - loaded.classifier = (NaiveBayes) SerializationHelper.read(filePath); - return loaded; - } - public Instances getTrainingData() { return new Instances(trainingHeader); } From 2142a0d9ef19f47c331e39ef1c7e54e2687385af Mon Sep 17 00:00:00 2001 From: Brian-Weloba Date: Thu, 10 Apr 2025 11:46:04 +0100 Subject: [PATCH 6/7] Add attendee management functionality; implement attendees.jsp for displaying event RSVPs, enhance RSVPServlet to fetch attendees, and update styles for improved UI. --- .../java/dev/brianweloba/dao/RsvpDAO.java | 10 ++ .../main/java/dev/brianweloba/model/User.java | 58 ++++++++ .../dev/brianweloba/servlet/RSVPServlet.java | 39 ++++++ .../main/webapp/WEB-INF/views/attendees.jsp | 59 ++++++++ .../main/webapp/WEB-INF/views/eventsBody.jsp | 128 ++++++++++++++---- emmas/src/main/webapp/css/styles.css | 66 +++++++++ 6 files changed, 337 insertions(+), 23 deletions(-) create mode 100644 emmas/src/main/java/dev/brianweloba/model/User.java create mode 100644 emmas/src/main/webapp/WEB-INF/views/attendees.jsp diff --git a/emmas/src/main/java/dev/brianweloba/dao/RsvpDAO.java b/emmas/src/main/java/dev/brianweloba/dao/RsvpDAO.java index cd14606..38c5799 100644 --- a/emmas/src/main/java/dev/brianweloba/dao/RsvpDAO.java +++ b/emmas/src/main/java/dev/brianweloba/dao/RsvpDAO.java @@ -6,6 +6,8 @@ import jakarta.persistence.EntityManager; import jakarta.persistence.NoResultException; +import java.util.List; + public class RsvpDAO { public void create(RSVP rsvp, Long eventId) { EntityManager manager = HibernateUtil.getEntityManager(); @@ -71,6 +73,14 @@ private Long getCurrentGuestCount(EntityManager manager, Event event) { } } + public List getEventAttendees(Long eventId) { + try (EntityManager manager = HibernateUtil.getEntityManager()) { + return manager.createQuery("SELECT r FROM RSVP r WHERE r.event.id = :eventId", RSVP.class) + .setParameter("eventId", eventId) + .getResultList(); + } + } + // Additional useful method public Event findEventWithRsvps(Long eventId) { try (EntityManager manager = HibernateUtil.getEntityManager()) { diff --git a/emmas/src/main/java/dev/brianweloba/model/User.java b/emmas/src/main/java/dev/brianweloba/model/User.java new file mode 100644 index 0000000..09bf028 --- /dev/null +++ b/emmas/src/main/java/dev/brianweloba/model/User.java @@ -0,0 +1,58 @@ +package dev.brianweloba.model; + +import jakarta.persistence.*; +import jakarta.validation.constraints.Email; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.Size; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import java.util.Date; + +@Entity +@Table(name = "users") +@Getter +@Setter +@NoArgsConstructor +public class User { + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @NotBlank(message = "Username is required") + @Size(min = 3, max = 50, message = "Username must be between 3 and 50 characters") + @Column(nullable = false, unique = true) + private String username; + + @NotBlank(message = "Password is required") + @Size(min = 8, message = "Password must be at least 8 characters long") + @Column(nullable = false) + private String password; + + @NotBlank(message = "Email is required") + @Email(message = "Email should be valid") + @Column(nullable = false, unique = true) + private String email; + + @Column(name = "created_at", insertable = false, updatable = false, + columnDefinition = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + private Date createdAt; + + @Column(name = "updated_at", insertable = false, updatable = false, + columnDefinition = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + private Date updatedAt; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof User)) return false; + return id != null && id.equals(((User) o).getId()); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } +} + diff --git a/emmas/src/main/java/dev/brianweloba/servlet/RSVPServlet.java b/emmas/src/main/java/dev/brianweloba/servlet/RSVPServlet.java index 8a82700..ddb6458 100644 --- a/emmas/src/main/java/dev/brianweloba/servlet/RSVPServlet.java +++ b/emmas/src/main/java/dev/brianweloba/servlet/RSVPServlet.java @@ -4,18 +4,57 @@ import dev.brianweloba.dao.EventDAO; import dev.brianweloba.dao.RsvpDAO; import dev.brianweloba.model.RSVP; +import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.List; @WebServlet("/events/rsvp") public class RSVPServlet extends HttpServlet { private final RsvpDAO rsvpDAO = new RsvpDAO(); private final EventDAO eventDAO = new EventDAO(); + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException,IOException { + try{ + String eventIdParam = request.getParameter("eventId"); + if(eventIdParam == null || eventIdParam.trim().isEmpty()){ + request.setAttribute("error", "Event ID is required"); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + return; + } + long eventId; + try { + eventId = Long.parseLong(eventIdParam); + } catch (NumberFormatException e) { + request.setAttribute("error", "Invalid Event ID format"); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + return; + } + + List attendees = rsvpDAO.getEventAttendees(eventId); + + int totalGuests = 0; + for (RSVP rsvp : attendees) { + totalGuests += rsvp.getGuests(); + } + + request.setAttribute("attendees", attendees); + request.setAttribute("totalGuests", totalGuests); + + request.getRequestDispatcher("/WEB-INF/views/attendees.jsp").forward(request, response); + + } catch (Exception e) { + request.setAttribute("error", "Error fetching attendees: " + e.getMessage()); + request.getRequestDispatcher("/WEB-INF/views/error.jsp").forward(request, response); + } + + + } @Override protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { response.setContentType("application/json"); diff --git a/emmas/src/main/webapp/WEB-INF/views/attendees.jsp b/emmas/src/main/webapp/WEB-INF/views/attendees.jsp new file mode 100644 index 0000000..8d0d64d --- /dev/null +++ b/emmas/src/main/webapp/WEB-INF/views/attendees.jsp @@ -0,0 +1,59 @@ +<%@ taglib prefix="c" uri="http://java.sun.com/jsp/jstl/core" %> +<%@ taglib prefix="fmt" uri="http://java.sun.com/jsp/jstl/fmt" %> + + + + +
+
+ + + Showing ${attendees.size()} RSVPs with a total of ${totalGuests} attendees + + + No RSVPs yet + + +
+ +
+ + +
    + +
  • +
    +
    + + + +
    +
    +

    + +

    +

    + +

    +
    +
    + + + +${attendee.guests - 1} guests + + + 1 person + + +
    +
    +
  • +
    +
+
+ +

No one has RSVP'd to this event yet.

+
+
+
+
\ No newline at end of file diff --git a/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp b/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp index 447549d..99c275a 100644 --- a/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp +++ b/emmas/src/main/webapp/WEB-INF/views/eventsBody.jsp @@ -29,18 +29,10 @@ -
`, - ``, - ``, - ``, - ``, - `${event.eventCapacity}`, - ``, - `` - )"> +
+
${event.eventType}
@@ -87,7 +79,8 @@
-
+
@@ -95,16 +88,17 @@
Attendance: - - - - - - ${totalGuests}/${event.eventCapacity} - (${rsvpCount} RSVPs) - + + + + + + ${totalGuests}/${event.eventCapacity} + (${rsvpCount} RSVPs) +
+
@@ -119,7 +113,18 @@ Hosted by: ${event.eventHost}
-
- +