From b7df6fcf2f3e039cdaaacaffc8aa323f9b4c225e Mon Sep 17 00:00:00 2001 From: David Calavera Date: Tue, 28 Nov 2023 09:49:08 -0800 Subject: [PATCH] Extract the request ID without allocating extra memory. (#735) Changes the way that the Context is initialized to receive the request ID as an argument. This way we also avoid allocating additional memory for it. Signed-off-by: David Calavera --- lambda-runtime/src/lib.rs | 19 ++------ lambda-runtime/src/types.rs | 93 ++++++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index 5404fb96..ccd35ab0 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -17,7 +17,6 @@ use hyper::{ use lambda_runtime_api_client::Client; use serde::{Deserialize, Serialize}; use std::{ - convert::TryFrom, env, fmt::{self, Debug, Display}, future::Future, @@ -41,6 +40,8 @@ mod types; use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}; pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse}; +use types::invoke_request_id; + /// Error type that lambdas may result in pub type Error = lambda_runtime_api_client::Error; @@ -121,6 +122,7 @@ where trace!("New event arrived (run loop)"); let event = next_event_response?; let (parts, body) = event.into_parts(); + let request_id = invoke_request_id(&parts.headers)?; #[cfg(debug_assertions)] if parts.status == http::StatusCode::NO_CONTENT { @@ -130,19 +132,8 @@ where continue; } - let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?; - let request_id = &ctx.request_id.clone(); - - let request_span = match &ctx.xray_trace_id { - Some(trace_id) => { - env::set_var("_X_AMZN_TRACE_ID", trace_id); - tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id) - } - None => { - env::remove_var("_X_AMZN_TRACE_ID"); - tracing::info_span!("Lambda runtime invoke", requestId = request_id) - } - }; + let ctx: Context = Context::new(request_id, self.config.clone(), &parts.headers)?; + let request_span = ctx.request_span(); // Group the handling in one future and instrument it with the span async { diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index a252475b..82d9b21f 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -1,15 +1,16 @@ use crate::{Error, RefConfig}; use base64::prelude::*; use bytes::Bytes; -use http::{HeaderMap, HeaderValue, StatusCode}; +use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, - convert::TryFrom, + env, fmt::Debug, time::{Duration, SystemTime}, }; use tokio_stream::Stream; +use tracing::Span; #[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -120,11 +121,10 @@ pub struct Context { pub env_config: RefConfig, } -impl TryFrom<(RefConfig, HeaderMap)> for Context { - type Error = Error; - fn try_from(data: (RefConfig, HeaderMap)) -> Result { - let env_config = data.0; - let headers = data.1; +impl Context { + /// Create a new [Context] struct based on the fuction configuration + /// and the incoming request data. + pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result { let client_context: Option = if let Some(value) = headers.get("lambda-runtime-client-context") { serde_json::from_str(value.to_str()?)? } else { @@ -138,11 +138,7 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context { }; let ctx = Context { - request_id: headers - .get("lambda-runtime-aws-request-id") - .expect("missing lambda-runtime-aws-request-id header") - .to_str()? - .to_owned(), + request_id: request_id.to_owned(), deadline: headers .get("lambda-runtime-deadline-ms") .expect("missing lambda-runtime-deadline-ms header") @@ -165,13 +161,37 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context { Ok(ctx) } -} -impl Context { /// The execution deadline for the current invocation. pub fn deadline(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline) } + + /// Create a new [`tracing::Span`] for an incoming invocation. + pub(crate) fn request_span(&self) -> Span { + match &self.xray_trace_id { + Some(trace_id) => { + env::set_var("_X_AMZN_TRACE_ID", trace_id); + tracing::info_span!( + "Lambda runtime invoke", + requestId = &self.request_id, + xrayTraceId = trace_id + ) + } + None => { + env::remove_var("_X_AMZN_TRACE_ID"); + tracing::info_span!("Lambda runtime invoke", requestId = &self.request_id) + } + } + } +} + +/// Extract the invocation request id from the incoming request. +pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> { + headers + .get("lambda-runtime-aws-request-id") + .expect("missing lambda-runtime-aws-request-id header") + .to_str() } /// Incoming Lambda request containing the event payload and context. @@ -313,7 +333,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); } @@ -324,7 +344,7 @@ mod test { let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); } @@ -355,7 +375,7 @@ mod test { ); let config = Arc::new(Config::default()); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.client_context.is_some()); @@ -369,7 +389,7 @@ mod test { headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); assert!(tried.unwrap().client_context.is_some()); } @@ -390,7 +410,7 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_str(&cognito_identity_str).unwrap(), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.identity.is_some()); @@ -412,7 +432,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -427,7 +447,7 @@ mod test { "lambda-runtime-client-context", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -439,7 +459,7 @@ mod test { headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -454,14 +474,13 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } #[test] #[should_panic] - #[allow(unused_must_use)] - fn context_with_missing_request_id_should_panic() { + fn context_with_missing_deadline_should_panic() { let config = Arc::new(Config::default()); let mut headers = HeaderMap::new(); @@ -471,15 +490,26 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from((config, headers)); + let _ = Context::new("id", config, &headers); } #[test] - #[should_panic] - #[allow(unused_must_use)] - fn context_with_missing_deadline_should_panic() { - let config = Arc::new(Config::default()); + fn invoke_request_id_should_not_panic() { + let mut headers = HeaderMap::new(); + headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); + headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); + headers.insert( + "lambda-runtime-invoked-function-arn", + HeaderValue::from_static("arn::myarn"), + ); + headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); + + let _ = invoke_request_id(&headers); + } + #[test] + #[should_panic] + fn invoke_request_id_should_panic() { let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert( @@ -487,6 +517,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from((config, headers)); + + let _ = invoke_request_id(&headers); } }