cspj-application/server-ml/main.py

61 lines
1.9 KiB
Python
Raw Permalink Normal View History

2025-02-06 02:43:26 +08:00
from flask import Flask, request, jsonify
2025-02-06 06:56:15 +08:00
import torch
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
2025-02-06 01:28:40 +08:00
2025-02-13 02:44:00 +08:00
print("Starting server...")

# Flask application object; the /predict route below is registered on it.
app = Flask(__name__)

# Inference is pinned to the CPU. To opt into a GPU when one is present, build
# the device from "cuda" if torch.cuda.is_available() else "cpu" instead.
device = torch.device("cpu")

# Load the tokenizer and the classifier once at start-up so every request
# reuses the same in-memory objects.
print("Loading model...")
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained(
    "cssupport/mobilebert-sql-injection-detect"
)
model = model.to(device)
model.eval()  # evaluation mode: turns off training-only layers such as dropout
print("Model loaded")
2025-02-06 06:56:15 +08:00
2025-02-13 02:44:00 +08:00
# function for model to predict sql injection
2025-02-06 06:56:15 +08:00
def predict(text):
    """Classify ``text`` with the loaded sequence-classification model.

    The text is tokenized (truncated to at most 512 tokens, no padding), run
    through the model without gradient tracking, and the resulting logits are
    converted to probabilities with a softmax over the class dimension.

    Args:
        text: The input to classify (the API endpoint passes the raw query).

    Returns:
        A ``(predicted_class, confidence)`` tuple: the index of the most
        probable class as an int, and that class's probability as a float.
    """
    encoded = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits

    probs = torch.softmax(logits, dim=1)
    best = torch.argmax(probs, dim=1).item()
    return best, probs[0][best].item()
2025-02-13 02:44:00 +08:00
# the api endpoint
2025-02-06 06:56:15 +08:00
@app.route("/predict", methods=["POST"])
def classify_query():
    """Classify the query sent in the JSON body of a POST to ``/predict``.

    Expects a JSON object of the form ``{"query": "<text>"}``.

    Returns:
        On success, a JSON object with the original ``query``, a ``result``
        of ``"fail"`` (flagged) or ``"pass"``, and the model ``confidence``
        rounded to two decimals.
        A 400 response carrying an ``error`` message when the body is not a
        JSON object, has no ``query`` key, or ``query`` is not a string.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising, so every bad payload goes through the same 400 path below.
    data = request.get_json(silent=True)

    # The body may legally be any JSON value (null, list, string, number);
    # only an object can carry the "query" key.
    if not isinstance(data, dict) or "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400

    query = data["query"]
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)

    # Class 1 is treated as the "bad" label; flag the query only when the
    # model picks it with a confidence above 0.7.
    is_vulnerable = predicted_class == 1 and confidence > 0.7
    result = {
        "query": query,
        "result": "fail" if is_vulnerable else "pass",
        "confidence": round(confidence, 2),
    }
    return jsonify(result)
2025-02-13 02:44:00 +08:00
# run the flask server
2025-02-06 06:56:15 +08:00
if __name__ == "__main__":
    # Serve on every network interface, port 5000, with Flask debug mode off.
    app.run(debug=False, port=5000, host="0.0.0.0")