"""Flask service that flags SQL queries as likely injection attempts.

Exposes one POST endpoint, ``/predict``, backed by the
``cssupport/mobilebert-sql-injection-detect`` MobileBERT classifier.
"""

from flask import Flask, request, jsonify
import torch
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification

print("Starting server...")

# Initialize the Flask app.
app = Flask(__name__)

# Inference is pinned to CPU. To use a GPU when one is present, replace with:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# A class-1 prediction is only reported as "fail" when its softmax
# probability is strictly above this value.
CONFIDENCE_THRESHOLD = 0.7
# Token limit passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512

# Load tokenizer and model once at import time so every request reuses them.
print("Loading model...")
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained(
    "cssupport/mobilebert-sql-injection-detect"
)
model.to(device)
model.eval()  # inference mode: disables dropout
print("Model loaded")


def predict(text, max_length=MAX_LENGTH):
    """Classify *text* with the SQL-injection model.

    Args:
        text: The query string to classify.
        max_length: Maximum number of tokens fed to the model; longer
            inputs are truncated.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the argmax class index
        (an ``int``) and its softmax probability (a ``float``).
    """
    inputs = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # A single (unbatched) input is tokenized, so row 0 holds its scores.
    probabilities = torch.softmax(outputs.logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][predicted_class].item()
    return predicted_class, confidence


@app.route("/predict", methods=["POST"])
def classify_query():
    """Classify the ``query`` field of a JSON request body.

    Expects a JSON object such as ``{"query": "SELECT ..."}``.

    Returns:
        200 with ``{"query", "result", "confidence"}``, where ``result`` is
        ``"fail"`` when the model predicts class 1 with confidence above
        ``CONFIDENCE_THRESHOLD`` and ``"pass"`` otherwise. Returns 400 with
        an ``error`` message when the body is not a JSON object containing
        a string ``query``.
    """
    # silent=True yields None (instead of raising) for a missing or
    # unparsable body. The isinstance check also rejects JSON null, strings,
    # and lists, which previously crashed the membership test / indexing
    # with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400

    query = data["query"]
    # The tokenizer only accepts text; reject other JSON types up front.
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)

    # Class 1 is treated as an injection; require high confidence to fail.
    is_vulnerable = predicted_class == 1 and confidence > CONFIDENCE_THRESHOLD

    result = {
        "query": query,
        "result": "fail" if is_vulnerable else "pass",
        "confidence": round(confidence, 2),
    }
    return jsonify(result)


# Run the Flask development server, listening on all interfaces.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=False)