Skip to content

Commit

Permalink
Merge pull request #49 from pranav-bhatt/master
Browse files Browse the repository at this point in the history
ImageHub(Feature addition): Generalised the paths and modularized the filter for easy addition of rules
  • Loading branch information
leecalcote authored Mar 21, 2021
2 parents 636118e + 7b4e81f commit 7963237
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 105 deletions.
9 changes: 4 additions & 5 deletions rate-limit-filter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ edition = "2018"
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"
base64 = "0.13.0"
bincode = "1.0"
proxy-wasm = "^0.1"
serde = { version = "1.0", default-features = false, features = ["derive"] }
bincode = "1.0"
#postgres = "^0.19.0"
base64 = "0.12.1"
serde_json ="1.0"
serde_json ="1.0"
wasm-bindgen = "0.2"
Binary file modified rate-limit-filter/pkg/rate_limit_filter_bg.wasm
Binary file not shown.
5 changes: 5 additions & 0 deletions rate-limit-filter/src/json_parse/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod rules;

pub use rules::JsonPath;
pub use rules::RateLimiterJson;
pub use rules::Rule;
21 changes: 21 additions & 0 deletions rate-limit-filter/src/json_parse/rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use serde::Deserialize;

#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
pub struct JsonPath {
pub name: String,
pub rule: Rule,
}

#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
#[serde(rename_all(deserialize = "kebab-case"))]
#[serde(tag = "ruleType", content = "parameters")]
pub enum Rule {
RateLimiter(Vec<RateLimiterJson>),
None,
}

#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
pub struct RateLimiterJson {
pub identifier: String,
pub limit: u32,
}
222 changes: 152 additions & 70 deletions rate-limit-filter/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,28 @@
mod json_parse;
mod rate_limiter;

//use postgres::{Client, Error, NoTls};
use json_parse::{JsonPath, RateLimiterJson, Rule};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use rate_limiter::RateLimiter;
use serde::{Deserialize, Serialize};
use serde::Deserialize;

use std::collections::HashMap;
use std::time::SystemTime;

// We need to make sure a HTTP root context is created and initialized when the filter is initialized.
// The _start() function initialises this root context
#[no_mangle]
pub fn _start() {
proxy_wasm::set_log_level(LogLevel::Info);
proxy_wasm::set_http_context(|_context_id, _root_context_id| -> Box<dyn HttpContext> {
Box::new(UpstreamCall::new())
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(UpstreamCallRoot {
config_json: HashMap::new(),
})
});
}

#[derive(Debug)]
struct UpstreamCall {
//paths: Vec<String>,
}

impl UpstreamCall {
fn new() -> Self {
return Self {
//paths: retrieve().unwrap(),
};
}
}

/*
fn retrieve() -> Result<Vec<String>, Error> {
let mut client = Client::connect("host=localhost user=postgres dbname=mesherydb", NoTls)?;
let array = client
.query("SELECT PathName FROM Paths", &[])?
.iter()
.map(|x| x.get(0))
.collect();
Ok(array)
}
*/

//to be removed
static ALLOWED_PATHS: [&str; 4] = ["/auth", "/signup", "/upgrade", "/pull"];
// Defining standard CORS headers
static CORS_HEADERS: [(&str, &str); 5] = [
("Powered-By", "proxy-wasm"),
("Access-Control-Allow-Origin", "*"),
Expand All @@ -51,57 +31,131 @@ static CORS_HEADERS: [(&str, &str); 5] = [
("Access-Control-Max-Age", "3600"),
];

#[derive(Serialize, Deserialize, Debug)]
// This struct is what the JWT token sent by the user will deserialize to
#[derive(Deserialize, Debug)]
struct Data {
username: String,
plan: String,
}

// This is the instance of a call made. It sorta derives from the root context
#[derive(Debug)]
struct UpstreamCall {
config_json: HashMap<String, Rule>,
}

impl UpstreamCall {
// Takes in the HashMap created in the root context mapping path name to rule type
fn new(json_hm: &HashMap<String, Rule>) -> Self {
Self {
//TODO this clone is super heavy, find a way to get rid of it
config_json: json_hm.clone(),
}
}

// Check if the path specified in the incoming request's path header has rule type None.
// Returns Option containing path name that was sent
fn rule_is_none(&self, path: String) -> Option<String> {
let rule = self.config_json.get(&path).unwrap();
// checking based only on type
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::None) {
return Some(path);
}
return None;
}

// Check if the path specified in the incoming request's path header has rule type RateLimiter.
// Returns Option containing vector of RateLimiterJson objects (list of plan names with limits)
fn rule_is_rate_limiter(&self, path: String) -> Option<Vec<RateLimiterJson>> {
let rule = self.config_json.get(&path).unwrap();
// checking based only on type
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::RateLimiter(Vec::new())) {
if let Rule::RateLimiter(plans_vec) = rule {
return Some(plans_vec.to_vec());
}
}
return None;
}
}

impl Context for UpstreamCall {}

impl HttpContext for UpstreamCall {
fn on_http_request_headers(&mut self, _num_headers: usize) -> Action {
// Options
if let Some(method) = self.get_http_request_header(":method") {
if method == "OPTIONS" {
self.send_http_response(204, CORS_HEADERS.to_vec(), None);
return Action::Pause;
}
}
if let Some(path) = self.get_http_request_header(":path") {
if ALLOWED_PATHS.binary_search(&path.as_str()).is_ok() {
return Action::Continue;
}

// Action for rule type: None
if let Some(_) = self.rule_is_none(self.get_http_request_header(":path").unwrap()) {
return Action::Continue;
}
/*
if let Some(path) = self.get_http_request_header(":path") {
if self.paths.binary_search(&path.to_string()).is_ok() {
return Action::Continue;
}
}*/
if let Some(header) = self.get_http_request_header("Authorization") {
if let Ok(token) = base64::decode(header) {
let obj: Data = serde_json::from_slice(&token).unwrap();
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", obj).as_str()).ok();
let curr = self.get_current_time();
let tm = curr.duration_since(SystemTime::UNIX_EPOCH).unwrap();
let mn = (tm.as_secs() / 60) % 60;
let _sc = tm.as_secs() % 60;
let mut rl = RateLimiter::get(obj.username, obj.plan);

let mut headers = CORS_HEADERS.to_vec();
let count: String;

if !rl.update(mn as i32) {

// Action for rule type: RateLimiter
if let Some(plans_vec) =
self.rule_is_rate_limiter(self.get_http_request_header(":path").unwrap())
{
if let Some(header) = self.get_http_request_header("Authorization") {
// Decoding JWT token
if let Ok(token) = base64::decode(header) {
//Deserializing token
let obj: Data = serde_json::from_slice(&token).unwrap();

proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", obj).as_str())
.ok();

// Since the rate limit works on a rate per minute based quota, we find current time
let curr = self.get_current_time();
let tm = curr.duration_since(SystemTime::UNIX_EPOCH).unwrap();
let mn = (tm.as_secs() / 60) % 60;
let _sc = tm.as_secs() % 60;

// Initialise RateLimiter object
let mut rl = RateLimiter::get(&obj.username, &obj.plan);

// Initialising headers to send back
let mut headers = CORS_HEADERS.to_vec();
let count: String;

// Extracting limits based on plan stated in JWT token from the corresponding RateLimiterJson
let limit = plans_vec
.into_iter()
.filter(|x| x.identifier == obj.plan)
.map(|x| x.limit)
.collect::<Vec<u32>>();

// Checking if the appropriate plan exists
if limit.len() != 1 {
self.send_http_response(
429,
headers,
Some(b"Invalid plan name or duplicate plan names defined.\n"),
);
return Action::Pause;
}

//Update request count in RateLimiter object, and check if it exceeds limits
if rl.update(mn as i32) > limit[0] {
count = rl.count.to_string();
headers
.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
self.send_http_response(429, headers, Some(b"Limit exceeded.\n"));
rl.set();
return Action::Pause;
}
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", &rl).as_str())
.ok();
// set the new count in headers, and proxy_wasm storage
count = rl.count.to_string();
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
self.send_http_response(429, headers, Some(b"Limit exceeded.\n"));
rl.set();
return Action::Pause;
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
self.send_http_response(200, headers, Some(b"All Good!\n"));
return Action::Continue;
}
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", &rl).as_str()).ok();
count = rl.count.to_string();
rl.set();
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
self.send_http_response(200, headers, Some(b"All Good!\n"));
return Action::Continue;
}
}
self.send_http_response(401, CORS_HEADERS.to_vec(), Some(b"Unauthorized\n"));
Expand All @@ -115,10 +169,38 @@ impl HttpContext for UpstreamCall {
}
}

impl UpstreamCall {
// fn retrieve_rl(&self) -> RateLimiter {
// }
struct UpstreamCallRoot {
config_json: HashMap<String, Rule>,
}

impl Context for UpstreamCall {}
impl RootContext for UpstreamCall {}
impl Context for UpstreamCallRoot {}
impl<'a> RootContext for UpstreamCallRoot {
//TODO: Revisit this once the read only feature is released in Istio 1.10
// Get Base64 encoded JSON from envoy config file when WASM VM starts
fn on_vm_start(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_configuration() {
// bytestring passed by VM -> String of base64 encoded JSON
let config_str = String::from_utf8(config_bytes).unwrap();
// String of base64 encoded JSON -> bytestring of decoded JSON
let config_b64 = base64::decode(config_str).unwrap();
// bytestring of decoded JSON -> String of decoded JSON
let json_str = String::from_utf8(config_b64).unwrap();
// Deserializing JSON String into vector of JsonPath objects
let json_vec: Vec<JsonPath> = serde_json::from_str(&json_str).unwrap();
// Creating HashMap of pattern ("path name", "rule type") and saving into UpstreamCallRoot object
for i in json_vec {
self.config_json.insert(i.name, i.rule);
}
}
true
}

fn create_http_context(&self, _: u32) -> Option<Box<dyn HttpContext>> {
// creating UpstreamCall object for each new call
Some(Box::new(UpstreamCall::new(&self.config_json)))
}

fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
}
3 changes: 3 additions & 0 deletions rate-limit-filter/src/rate_limiter/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod rate_limiter;

pub use rate_limiter::*;
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,48 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RateLimiter {
pub rpm: Option<u32>,
// Tracks time
pub min: i32,
// Tracks number of calls made
pub count: u32,
// stores a key(username according to example)
pub key: String,
}

impl RateLimiter {
fn new(key: &String, plan: &String) -> Self {
let limit = match plan.as_str() {
"Enterprise" => Some(100),
"Team" => Some(50),
"Personal" => Some(10),
_ => None,
};
fn new(key: &String, _plan: &String) -> Self {
Self {
rpm: limit,
min: -1,
count: 0,
key: key.clone(),
}
}
pub fn get(key: String, plan: String) -> Self {
// Get key and plan from proxy_wasm shared data store (username+plan name)
pub fn get(key: &String, plan: &String) -> Self {
if let Ok(data) = proxy_wasm::hostcalls::get_shared_data(&key.clone()) {
if let Some(data) = data.0 {
let data: Option<Self> = bincode::deserialize(&data).unwrap_or(None);
if let Some(mut obj) = data {
let limit = match plan.as_str() {
"Enterprise" => Some(100),
"Team" => Some(50),
"Personal" => Some(10),
_ => None,
};
obj.rpm = limit;
if let Some(obj) = data {
return obj;
}
}
}
return Self::new(&key, &plan);
}
// Set key and plan in proxy_wasm shared data store (username+plan name)
pub fn set(&self) {
let target: Option<Self> = Some(self.clone());
let encoded: Vec<u8> = bincode::serialize(&target).unwrap();
proxy_wasm::hostcalls::set_shared_data(&self.key.clone(), Some(&encoded), None).ok();
}
pub fn update(&mut self, time: i32) -> bool {
// Update time (minute by minute) and increment count
pub fn update(&mut self, time: i32) -> u32 {
if self.min != time {
self.min = time;
self.count = 0;
}
self.count += 1;
proxy_wasm::hostcalls::log(
LogLevel::Debug,
format!("Obj {:?} {:?}", self.count, self.rpm).as_str(),
)
.ok();
if let Some(sm) = self.rpm {
if self.count > sm {
return false;
}
}
return true;
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?} ", self.count).as_str()).ok();
self.count
}
}

0 comments on commit 7963237

Please sign in to comment.