cspj-application/server-ml/main.py
2025-02-06 06:56:15 +08:00

56 lines
1.8 KiB
Python

from flask import Flask, request, jsonify
import torch
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
# Flask application exposing the SQL-injection classifier over HTTP.
app = Flask(__name__)

# Prefer a CUDA GPU for inference when one is available; otherwise use CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The fine-tuned classifier checkpoint is used together with the stock
# MobileBERT uncased tokenizer. Both are loaded once, at import time.
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained(
    "cssupport/mobilebert-sql-injection-detect"
)
# Place the weights on the selected device and switch the module to
# evaluation mode (disables dropout) since it is only used for inference.
model = model.to(device)
model.eval()
def predict(text):
    """Classify one piece of text with the SQL-injection model.

    The text is tokenized without padding and truncated to at most 512
    tokens. It is then run through the module-level ``model`` on ``device``
    with gradient tracking disabled.

    Args:
        text: The input to classify.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        index of the highest-probability class (Python ``int``), and
        ``confidence`` is that class's softmax probability (Python ``float``).
    """
    encoded = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits

    # Softmax over the class dimension; row 0 holds the single input.
    probs = torch.softmax(logits, dim=1)
    best = torch.argmax(probs, dim=1).item()
    return best, probs[0][best].item()
# Define API endpoint
@app.route("/predict", methods=["POST"])
def classify_query():
    """Handle ``POST /predict``: classify a query string for SQL injection.

    Expects a JSON object body of the form ``{"query": "<text>"}``.

    Returns:
        200 with ``{"query", "classification", "confidence"}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object, when
        it lacks ``"query"``, or when ``"query"`` is not a string.
    """
    # silent=True yields None instead of raising on a missing, malformed or
    # non-JSON body. A valid JSON body may still be a list, string, number
    # or null, so require a dict before using `in` / indexing: on a str,
    # `"query" in data` is a substring test and data["query"] then crashes.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400
    query = data["query"]
    # The tokenizer treats a list as a batch (only row 0 would be reported),
    # and None/numbers fail inside it with a 500; accept strings only.
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)

    # Thresholding: flag an injection only when the model predicts class 1
    # with probability above 0.7. NOTE: `confidence` is the probability of
    # the *predicted* class. A class-1 prediction at <= 0.7 is therefore
    # reported as "No SQL Injection Detected" alongside that class-1
    # probability.
    is_vulnerable = predicted_class == 1 and confidence > 0.7
    result = {
        "query": query,
        "classification": "SQL Injection Detected" if is_vulnerable else "No SQL Injection Detected",
        "confidence": round(confidence, 2),
    }
    return jsonify(result)
# Run the Flask development server when this file is executed directly.
if __name__ == "__main__":
    import os

    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary Python. Because the server binds to all interfaces
    # (0.0.0.0), debug mode must be an explicit opt-in (FLASK_DEBUG=1) for
    # local development rather than always on. Truthiness matches Flask's
    # own FLASK_DEBUG parsing: unset/empty/"0"/"false"/"no" mean off.
    _debug_env = os.environ.get("FLASK_DEBUG", "")
    debug = bool(_debug_env) and _debug_env.lower() not in {"0", "false", "no"}
    app.run(host="0.0.0.0", port=5000, debug=debug)