1use crate::{
29 body::{Bytes, HttpBody},
30 BoxError,
31};
32use axum_core::{
33 body::Body,
34 response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::stream::{Stream, TryStream};
38use http_body::Frame;
39use pin_project_lite::pin_project;
40use std::{
41 fmt,
42 future::Future,
43 pin::Pin,
44 task::{ready, Context, Poll},
45 time::Duration,
46};
47use sync_wrapper::SyncWrapper;
48use tokio::time::Sleep;
49
50#[derive(Clone)]
52#[must_use]
53pub struct Sse<S> {
54 stream: S,
55 keep_alive: Option<KeepAlive>,
56}
57
58impl<S> Sse<S> {
59 pub fn new(stream: S) -> Self
64 where
65 S: TryStream<Ok = Event> + Send + 'static,
66 S::Error: Into<BoxError>,
67 {
68 Sse {
69 stream,
70 keep_alive: None,
71 }
72 }
73
74 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
78 self.keep_alive = Some(keep_alive);
79 self
80 }
81}
82
83impl<S> fmt::Debug for Sse<S> {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("Sse")
86 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
87 .field("keep_alive", &self.keep_alive)
88 .finish()
89 }
90}
91
92impl<S, E> IntoResponse for Sse<S>
93where
94 S: Stream<Item = Result<Event, E>> + Send + 'static,
95 E: Into<BoxError>,
96{
97 fn into_response(self) -> Response {
98 (
99 [
100 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
101 (http::header::CACHE_CONTROL, "no-cache"),
102 ],
103 Body::new(SseBody {
104 event_stream: SyncWrapper::new(self.stream),
105 keep_alive: self.keep_alive.map(KeepAliveStream::new),
106 }),
107 )
108 .into_response()
109 }
110}
111
112pin_project! {
113 struct SseBody<S> {
114 #[pin]
115 event_stream: SyncWrapper<S>,
116 #[pin]
117 keep_alive: Option<KeepAliveStream>,
118 }
119}
120
121impl<S, E> HttpBody for SseBody<S>
122where
123 S: Stream<Item = Result<Event, E>>,
124{
125 type Data = Bytes;
126 type Error = E;
127
128 fn poll_frame(
129 self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
132 let this = self.project();
133
134 match this.event_stream.get_pin_mut().poll_next(cx) {
135 Poll::Pending => {
136 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
137 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
138 } else {
139 Poll::Pending
140 }
141 }
142 Poll::Ready(Some(Ok(event))) => {
143 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
144 keep_alive.reset();
145 }
146 Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
147 }
148 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
149 Poll::Ready(None) => Poll::Ready(None),
150 }
151 }
152}
153
154#[derive(Debug, Default, Clone)]
156#[must_use]
157pub struct Event {
158 buffer: BytesMut,
159 flags: EventFlags,
160}
161
162impl Event {
163 pub fn data<T>(mut self, data: T) -> Event
178 where
179 T: AsRef<str>,
180 {
181 if self.flags.contains(EventFlags::HAS_DATA) {
182 panic!("Called `EventBuilder::data` multiple times");
183 }
184
185 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
186 self.field("data", line);
187 }
188
189 self.flags.insert(EventFlags::HAS_DATA);
190
191 self
192 }
193
194 #[cfg(feature = "json")]
204 pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
205 where
206 T: serde::Serialize,
207 {
208 struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
209 impl std::io::Write for IgnoreNewLines<'_> {
210 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
211 let mut last_split = 0;
212 for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
213 self.0.write_all(&buf[last_split..delimiter])?;
214 last_split = delimiter + 1;
215 }
216 self.0.write_all(&buf[last_split..])?;
217 Ok(buf.len())
218 }
219
220 fn flush(&mut self) -> std::io::Result<()> {
221 self.0.flush()
222 }
223 }
224 if self.flags.contains(EventFlags::HAS_DATA) {
225 panic!("Called `EventBuilder::json_data` multiple times");
226 }
227
228 self.buffer.extend_from_slice(b"data: ");
229 serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
230 .map_err(axum_core::Error::new)?;
231 self.buffer.put_u8(b'\n');
232
233 self.flags.insert(EventFlags::HAS_DATA);
234
235 Ok(self)
236 }
237
238 pub fn comment<T>(mut self, comment: T) -> Event
249 where
250 T: AsRef<str>,
251 {
252 self.field("", comment.as_ref());
253 self
254 }
255
256 pub fn event<T>(mut self, event: T) -> Event
271 where
272 T: AsRef<str>,
273 {
274 if self.flags.contains(EventFlags::HAS_EVENT) {
275 panic!("Called `EventBuilder::event` multiple times");
276 }
277 self.flags.insert(EventFlags::HAS_EVENT);
278
279 self.field("event", event.as_ref());
280
281 self
282 }
283
284 pub fn retry(mut self, duration: Duration) -> Event {
294 if self.flags.contains(EventFlags::HAS_RETRY) {
295 panic!("Called `EventBuilder::retry` multiple times");
296 }
297 self.flags.insert(EventFlags::HAS_RETRY);
298
299 self.buffer.extend_from_slice(b"retry:");
300
301 let secs = duration.as_secs();
302 let millis = duration.subsec_millis();
303
304 if secs > 0 {
305 self.buffer
307 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
308
309 if millis < 10 {
311 self.buffer.extend_from_slice(b"00");
312 } else if millis < 100 {
313 self.buffer.extend_from_slice(b"0");
314 }
315 }
316
317 self.buffer
319 .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
320
321 self.buffer.put_u8(b'\n');
322
323 self
324 }
325
326 pub fn id<T>(mut self, id: T) -> Event
339 where
340 T: AsRef<str>,
341 {
342 if self.flags.contains(EventFlags::HAS_ID) {
343 panic!("Called `EventBuilder::id` multiple times");
344 }
345 self.flags.insert(EventFlags::HAS_ID);
346
347 let id = id.as_ref().as_bytes();
348 assert_eq!(
349 memchr::memchr(b'\0', id),
350 None,
351 "Event ID cannot contain null characters",
352 );
353
354 self.field("id", id);
355 self
356 }
357
358 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
359 let value = value.as_ref();
360 assert_eq!(
361 memchr::memchr2(b'\r', b'\n', value),
362 None,
363 "SSE field value cannot contain newlines or carriage returns",
364 );
365 self.buffer.extend_from_slice(name.as_bytes());
366 self.buffer.put_u8(b':');
367 self.buffer.put_u8(b' ');
368 self.buffer.extend_from_slice(value);
369 self.buffer.put_u8(b'\n');
370 }
371
372 fn finalize(mut self) -> Bytes {
373 self.buffer.put_u8(b'\n');
374 self.buffer.freeze()
375 }
376}
377
/// Bit set recording which once-only [`Event`] fields have been written.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// Returns the raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Wraps a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        (other.bits() & !self.bits()) == 0
    }

    /// Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
403
404#[derive(Debug, Clone)]
407#[must_use]
408pub struct KeepAlive {
409 event: Bytes,
410 max_interval: Duration,
411}
412
413impl KeepAlive {
414 pub fn new() -> Self {
416 Self {
417 event: Bytes::from_static(b":\n\n"),
418 max_interval: Duration::from_secs(15),
419 }
420 }
421
422 pub fn interval(mut self, time: Duration) -> Self {
426 self.max_interval = time;
427 self
428 }
429
430 pub fn text<I>(self, text: I) -> Self
439 where
440 I: AsRef<str>,
441 {
442 self.event(Event::default().comment(text))
443 }
444
445 pub fn event(mut self, event: Event) -> Self {
454 self.event = event.finalize();
455 self
456 }
457}
458
459impl Default for KeepAlive {
460 fn default() -> Self {
461 Self::new()
462 }
463}
464
465pin_project! {
466 #[derive(Debug)]
467 struct KeepAliveStream {
468 keep_alive: KeepAlive,
469 #[pin]
470 alive_timer: Sleep,
471 }
472}
473
474impl KeepAliveStream {
475 fn new(keep_alive: KeepAlive) -> Self {
476 Self {
477 alive_timer: tokio::time::sleep(keep_alive.max_interval),
478 keep_alive,
479 }
480 }
481
482 fn reset(self: Pin<&mut Self>) {
483 let this = self.project();
484 this.alive_timer
485 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
486 }
487
488 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
489 let this = self.as_mut().project();
490
491 ready!(this.alive_timer.poll(cx));
492
493 let event = this.keep_alive.event.clone();
494
495 self.reset();
496
497 Poll::Ready(event)
498 }
499}
500
501fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
502 MemchrSplit {
503 needle,
504 haystack: Some(haystack),
505 }
506}
507
508struct MemchrSplit<'a> {
509 needle: u8,
510 haystack: Option<&'a [u8]>,
511}
512
513impl<'a> Iterator for MemchrSplit<'a> {
514 type Item = &'a [u8];
515 fn next(&mut self) -> Option<Self::Item> {
516 let haystack = self.haystack?;
517 if let Some(pos) = memchr::memchr(self.needle, haystack) {
518 let (front, back) = haystack.split_at(pos);
519 self.haystack = Some(&back[1..]);
520 Some(front)
521 } else {
522 self.haystack.take()
523 }
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        // `field` always writes `name: ` followed by the value verbatim, so a
        // value that starts with a space ends up with two spaces after the colon.
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn multiline_data_is_split_into_fields() {
        let event = Event::default().data("one\ntwo");
        assert_eq!(&*event.finalize(), b"data: one\ndata: two\n\n");

        // A trailing newline produces a trailing empty `data:` line.
        let event = Event::default().data("one\n");
        assert_eq!(&*event.finalize(), b"data: one\ndata: \n\n");
    }

    #[test]
    fn retry_is_written_in_padded_milliseconds() {
        let event = Event::default().retry(Duration::from_millis(42));
        assert_eq!(&*event.finalize(), b"retry:42\n\n");

        let event = Event::default().retry(Duration::from_millis(1005));
        assert_eq!(&*event.finalize(), b"retry:1005\n\n");

        let event = Event::default().retry(Duration::from_millis(1050));
        assert_eq!(&*event.finalize(), b"retry:1050\n\n");

        let event = Event::default().retry(Duration::from_secs(2));
        assert_eq!(&*event.finalize(), b"retry:2000\n\n");
    }

    #[test]
    #[should_panic(expected = "Called `EventBuilder::data` multiple times")]
    fn data_twice_panics() {
        let _ = Event::default().data("a").data("b");
    }

    #[test]
    #[should_panic(expected = "Called `EventBuilder::id` multiple times")]
    fn id_twice_panics() {
        let _ = Event::default().id("a").id("b");
    }

    #[test]
    fn valid_json_raw_value_chars_stripped() {
        let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        assert_eq!(
            &*json_raw_value_event.finalize(),
            format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
        );
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // Each cycle starts with a real event...
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // ...followed by one keep-alive per idle second until the next one.
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // First real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // Keep-alives while waiting for the second event.
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // Second and last real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // Once the event stream ends, keep-alives stop as well.
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parses one serialized event into a `field name -> trimmed value` map.
    ///
    /// Comment lines (empty field name) are stored under the key `comment`.
    /// Asserts that nothing follows the terminating blank line.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}