"""Flask service that classifies a text query as SQL injection or not.

Uses the ``cssupport/mobilebert-sql-injection-detect`` MobileBERT sequence
classifier with the ``google/mobilebert-uncased`` tokenizer. Both are loaded
once at import time and shared by every request.
"""

import os

import torch
from flask import Flask, jsonify, request
from transformers import MobileBertForSequenceClassification, MobileBertTokenizer

# Minimum softmax probability of the predicted class required before a query
# is reported as an SQL injection.
CONFIDENCE_THRESHOLD = 0.7

# Initialize Flask app
app = Flask(__name__)

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model once at import time so requests reuse them.
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained(
    "cssupport/mobilebert-sql-injection-detect"
)
model.to(device)
model.eval()  # evaluation mode for inference (e.g. dropout disabled)


def predict(text):
    """Classify a single string with the SQL-injection model.

    Args:
        text: The query string to classify. Inputs longer than 512 tokens
            are truncated.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the argmax class index
        (int) and its softmax probability (float).
    """
    inputs = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # No gradients are needed for inference; saves memory and time.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    probabilities = torch.softmax(outputs.logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    # A single string is tokenized, so row 0 holds the only example.
    confidence = probabilities[0][predicted_class].item()
    return predicted_class, confidence


@app.route("/predict", methods=["POST"])
def classify_query():
    """Classify the ``query`` field of a JSON POST body.

    Expects a JSON object of the form ``{"query": "<text>"}``.

    Returns:
        200 with ``query``, ``classification`` and ``confidence`` (rounded to
        two decimals); 400 if the body is not a JSON object, lacks
        ``query``, or ``query`` is not a string.
    """
    # silent=True returns None for a missing/invalid JSON body or wrong
    # content type instead of raising, so the client gets a clear 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Without this check a JSON string body would make the membership
        # test below a substring search and then crash on indexing.
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400

    query = data["query"]
    if not isinstance(query, str):
        # The tokenizer would raise on non-string input, producing a 500.
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)

    # Class 1 is treated as the SQL-injection label (assumed from the
    # model's label mapping -- confirm against the model card); only flag
    # the query when the model is sufficiently confident.
    is_vulnerable = predicted_class == 1 and confidence > CONFIDENCE_THRESHOLD

    result = {
        "query": query,
        "classification": "SQL Injection Detected" if is_vulnerable else "No SQL Injection Detected",
        "confidence": round(confidence, 2),
    }
    return jsonify(result)


# Run Flask server
if __name__ == "__main__":
    # Never enable the Werkzeug debugger by default: combined with
    # host="0.0.0.0" it lets any network client execute arbitrary code.
    # Opt in for local development with FLASK_DEBUG=1.
    app.run(host="0.0.0.0", port=5000, debug=os.environ.get("FLASK_DEBUG") == "1")