cspj-application/server-ml/main.py

61 lines
1.9 KiB
Python

import os

from flask import Flask, request, jsonify
import torch
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
print("Starting server...")
# initialize Flask app
app = Flask(__name__)
# Inference device. Defaults to the CPU; set the MODEL_DEVICE environment
# variable (e.g. "cuda" or "cuda:0") to run on another torch device without
# editing this file. An invalid value makes torch.device raise at startup.
device = torch.device(os.environ.get("MODEL_DEVICE", "cpu"))
# load tokenizer and model (downloaded/cached by from_pretrained)
print("Loading model...")
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained("cssupport/mobilebert-sql-injection-detect")
model.to(device)
# eval() puts the module in inference mode (e.g. dropout disabled), so
# repeated predictions on the same text are consistent.
model.eval()
print("Model loaded")
# function for model to predict sql injection
def predict(text):
inputs = tokenizer(text, padding=False, truncation=True, return_tensors="pt", max_length=512)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence = probabilities[0][predicted_class].item()
return predicted_class, confidence
# Minimum probability the model must assign to label 1 (injection) before a
# query is reported as "fail".
CONFIDENCE_THRESHOLD = 0.7


# the api endpoint
@app.route("/predict", methods=["POST"])
def classify_query():
    """Classify a SQL query posted as JSON ``{"query": "<sql text>"}``.

    Responds with ``{"query", "result", "confidence"}``:
      * ``result`` is ``"fail"`` when the model predicts label 1 with a
        probability above ``CONFIDENCE_THRESHOLD``, otherwise ``"pass"``;
      * ``confidence`` is the probability of the predicted label, rounded
        to two decimals.

    Responds with HTTP 400 and an ``{"error": ...}`` body when the request
    body is not a JSON object, lacks ``query``, or ``query`` is not a string.
    """
    # silent=True returns None for a missing/malformed body or a non-JSON
    # content type instead of raising, so every bad payload ends up as a 400
    # below rather than an unhandled 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400
    query = data["query"]
    # Numbers, null, lists or objects cannot be scored as one query; reject
    # them here instead of letting the tokenizer raise.
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)
    is_vulnerable = predicted_class == 1 and confidence > CONFIDENCE_THRESHOLD
    result = {
        "query": query,
        "result": "fail" if is_vulnerable else "pass",
        "confidence": round(confidence, 2),
    }
    return jsonify(result)
def main():
    """Start the Flask app on all interfaces, port 5000, with debug off."""
    app.run(debug=False, port=5000, host="0.0.0.0")


# run the flask server only when this file is executed directly
if __name__ == "__main__":
    main()