axum/response/
sse.rs

1//! Server-Sent Events (SSE) responses.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     Router,
8//!     routing::get,
9//!     response::sse::{Event, KeepAlive, Sse},
10//! };
11//! use std::{time::Duration, convert::Infallible};
//! use tokio_stream::StreamExt as _;
13//! use futures_util::stream::{self, Stream};
14//!
15//! let app = Router::new().route("/sse", get(sse_handler));
16//!
17//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
18//!     // A `Stream` that repeats an event every second
19//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
20//!         .map(Ok)
21//!         .throttle(Duration::from_secs(1));
22//!
23//!     Sse::new(stream).keep_alive(KeepAlive::default())
24//! }
25//! # let _: Router = app;
26//! ```
27
28use crate::{
29    body::{Bytes, HttpBody},
30    BoxError,
31};
32use axum_core::{
33    body::Body,
34    response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::stream::{Stream, TryStream};
38use http_body::Frame;
39use pin_project_lite::pin_project;
40use std::{
41    fmt,
42    future::Future,
43    pin::Pin,
44    task::{ready, Context, Poll},
45    time::Duration,
46};
47use sync_wrapper::SyncWrapper;
48use tokio::time::Sleep;
49
50/// An SSE response
51#[derive(Clone)]
52#[must_use]
53pub struct Sse<S> {
54    stream: S,
55    keep_alive: Option<KeepAlive>,
56}
57
58impl<S> Sse<S> {
59    /// Create a new [`Sse`] response that will respond with the given stream of
60    /// [`Event`]s.
61    ///
62    /// See the [module docs](self) for more details.
63    pub fn new(stream: S) -> Self
64    where
65        S: TryStream<Ok = Event> + Send + 'static,
66        S::Error: Into<BoxError>,
67    {
68        Sse {
69            stream,
70            keep_alive: None,
71        }
72    }
73
74    /// Configure the interval between keep-alive messages.
75    ///
76    /// Defaults to no keep-alive messages.
77    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
78        self.keep_alive = Some(keep_alive);
79        self
80    }
81}
82
83impl<S> fmt::Debug for Sse<S> {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("Sse")
86            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
87            .field("keep_alive", &self.keep_alive)
88            .finish()
89    }
90}
91
92impl<S, E> IntoResponse for Sse<S>
93where
94    S: Stream<Item = Result<Event, E>> + Send + 'static,
95    E: Into<BoxError>,
96{
97    fn into_response(self) -> Response {
98        (
99            [
100                (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
101                (http::header::CACHE_CONTROL, "no-cache"),
102            ],
103            Body::new(SseBody {
104                event_stream: SyncWrapper::new(self.stream),
105                keep_alive: self.keep_alive.map(KeepAliveStream::new),
106            }),
107        )
108            .into_response()
109    }
110}
111
112pin_project! {
113    struct SseBody<S> {
114        #[pin]
115        event_stream: SyncWrapper<S>,
116        #[pin]
117        keep_alive: Option<KeepAliveStream>,
118    }
119}
120
121impl<S, E> HttpBody for SseBody<S>
122where
123    S: Stream<Item = Result<Event, E>>,
124{
125    type Data = Bytes;
126    type Error = E;
127
128    fn poll_frame(
129        self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
132        let this = self.project();
133
134        match this.event_stream.get_pin_mut().poll_next(cx) {
135            Poll::Pending => {
136                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
137                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
138                } else {
139                    Poll::Pending
140                }
141            }
142            Poll::Ready(Some(Ok(event))) => {
143                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
144                    keep_alive.reset();
145                }
146                Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
147            }
148            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
149            Poll::Ready(None) => Poll::Ready(None),
150        }
151    }
152}
153
154/// Server-sent event
155#[derive(Debug, Default, Clone)]
156#[must_use]
157pub struct Event {
158    buffer: BytesMut,
159    flags: EventFlags,
160}
161
162impl Event {
163    /// Set the event's data data field(s) (`data: <content>`)
164    ///
165    /// Newlines in `data` will automatically be broken across `data: ` fields.
166    ///
167    /// This corresponds to [`MessageEvent`'s data field].
168    ///
169    /// Note that events with an empty data field will be ignored by the browser.
170    ///
171    /// # Panics
172    ///
173    /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
174    /// - Panics if `data` or `json_data` have already been called.
175    ///
176    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
177    pub fn data<T>(mut self, data: T) -> Event
178    where
179        T: AsRef<str>,
180    {
181        if self.flags.contains(EventFlags::HAS_DATA) {
182            panic!("Called `EventBuilder::data` multiple times");
183        }
184
185        for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
186            self.field("data", line);
187        }
188
189        self.flags.insert(EventFlags::HAS_DATA);
190
191        self
192    }
193
194    /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
195    ///
196    /// This corresponds to [`MessageEvent`'s data field].
197    ///
198    /// # Panics
199    ///
200    /// Panics if `data` or `json_data` have already been called.
201    ///
202    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
203    #[cfg(feature = "json")]
204    pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
205    where
206        T: serde::Serialize,
207    {
208        struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
209        impl std::io::Write for IgnoreNewLines<'_> {
210            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
211                let mut last_split = 0;
212                for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
213                    self.0.write_all(&buf[last_split..delimiter])?;
214                    last_split = delimiter + 1;
215                }
216                self.0.write_all(&buf[last_split..])?;
217                Ok(buf.len())
218            }
219
220            fn flush(&mut self) -> std::io::Result<()> {
221                self.0.flush()
222            }
223        }
224        if self.flags.contains(EventFlags::HAS_DATA) {
225            panic!("Called `EventBuilder::json_data` multiple times");
226        }
227
228        self.buffer.extend_from_slice(b"data: ");
229        serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
230            .map_err(axum_core::Error::new)?;
231        self.buffer.put_u8(b'\n');
232
233        self.flags.insert(EventFlags::HAS_DATA);
234
235        Ok(self)
236    }
237
238    /// Set the event's comment field (`:<comment-text>`).
239    ///
240    /// This field will be ignored by most SSE clients.
241    ///
242    /// Unlike other functions, this function can be called multiple times to add many comments.
243    ///
244    /// # Panics
245    ///
246    /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
247    /// comments.
248    pub fn comment<T>(mut self, comment: T) -> Event
249    where
250        T: AsRef<str>,
251    {
252        self.field("", comment.as_ref());
253        self
254    }
255
256    /// Set the event's name field (`event:<event-name>`).
257    ///
258    /// This corresponds to the `type` parameter given when calling `addEventListener` on an
259    /// [`EventSource`]. For example, `.event("update")` should correspond to
260    /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
261    /// [`message` event] instead.
262    ///
263    /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
264    /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
265    ///
266    /// # Panics
267    ///
268    /// - Panics if `event` contains any newlines or carriage returns.
269    /// - Panics if this function has already been called on this event.
270    pub fn event<T>(mut self, event: T) -> Event
271    where
272        T: AsRef<str>,
273    {
274        if self.flags.contains(EventFlags::HAS_EVENT) {
275            panic!("Called `EventBuilder::event` multiple times");
276        }
277        self.flags.insert(EventFlags::HAS_EVENT);
278
279        self.field("event", event.as_ref());
280
281        self
282    }
283
284    /// Set the event's retry timeout field (`retry:<timeout>`).
285    ///
286    /// This sets how long clients will wait before reconnecting if they are disconnected from the
287    /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
288    /// wish, such as if they implement exponential backoff.
289    ///
290    /// # Panics
291    ///
292    /// Panics if this function has already been called on this event.
293    pub fn retry(mut self, duration: Duration) -> Event {
294        if self.flags.contains(EventFlags::HAS_RETRY) {
295            panic!("Called `EventBuilder::retry` multiple times");
296        }
297        self.flags.insert(EventFlags::HAS_RETRY);
298
299        self.buffer.extend_from_slice(b"retry:");
300
301        let secs = duration.as_secs();
302        let millis = duration.subsec_millis();
303
304        if secs > 0 {
305            // format seconds
306            self.buffer
307                .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
308
309            // pad milliseconds
310            if millis < 10 {
311                self.buffer.extend_from_slice(b"00");
312            } else if millis < 100 {
313                self.buffer.extend_from_slice(b"0");
314            }
315        }
316
317        // format milliseconds
318        self.buffer
319            .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
320
321        self.buffer.put_u8(b'\n');
322
323        self
324    }
325
326    /// Set the event's identifier field (`id:<identifier>`).
327    ///
328    /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
329    /// the browser will set that field to the last known message ID, starting with the empty
330    /// string.
331    ///
332    /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
333    ///
334    /// # Panics
335    ///
336    /// - Panics if `id` contains any newlines, carriage returns or null characters.
337    /// - Panics if this function has already been called on this event.
338    pub fn id<T>(mut self, id: T) -> Event
339    where
340        T: AsRef<str>,
341    {
342        if self.flags.contains(EventFlags::HAS_ID) {
343            panic!("Called `EventBuilder::id` multiple times");
344        }
345        self.flags.insert(EventFlags::HAS_ID);
346
347        let id = id.as_ref().as_bytes();
348        assert_eq!(
349            memchr::memchr(b'\0', id),
350            None,
351            "Event ID cannot contain null characters",
352        );
353
354        self.field("id", id);
355        self
356    }
357
358    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
359        let value = value.as_ref();
360        assert_eq!(
361            memchr::memchr2(b'\r', b'\n', value),
362            None,
363            "SSE field value cannot contain newlines or carriage returns",
364        );
365        self.buffer.extend_from_slice(name.as_bytes());
366        self.buffer.put_u8(b':');
367        self.buffer.put_u8(b' ');
368        self.buffer.extend_from_slice(value);
369        self.buffer.put_u8(b'\n');
370    }
371
372    fn finalize(mut self) -> Bytes {
373        self.buffer.put_u8(b'\n');
374        self.buffer.freeze()
375    }
376}
377
/// Bit set recording which single-use fields of an [`Event`] have been written.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self(1);
    const HAS_EVENT: Self = Self(1 << 1);
    const HAS_RETRY: Self = Self(1 << 2);
    const HAS_ID: Self = Self(1 << 3);

    /// The raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Wrap a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        (other.bits() & !self.bits()) == 0
    }

    /// Set every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.0;
    }
}
403
404/// Configure the interval between keep-alive messages, the content
405/// of each message, and the associated stream.
406#[derive(Debug, Clone)]
407#[must_use]
408pub struct KeepAlive {
409    event: Bytes,
410    max_interval: Duration,
411}
412
413impl KeepAlive {
414    /// Create a new `KeepAlive`.
415    pub fn new() -> Self {
416        Self {
417            event: Bytes::from_static(b":\n\n"),
418            max_interval: Duration::from_secs(15),
419        }
420    }
421
422    /// Customize the interval between keep-alive messages.
423    ///
424    /// Default is 15 seconds.
425    pub fn interval(mut self, time: Duration) -> Self {
426        self.max_interval = time;
427        self
428    }
429
430    /// Customize the text of the keep-alive message.
431    ///
432    /// Default is an empty comment.
433    ///
434    /// # Panics
435    ///
436    /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
437    /// comments.
438    pub fn text<I>(self, text: I) -> Self
439    where
440        I: AsRef<str>,
441    {
442        self.event(Event::default().comment(text))
443    }
444
445    /// Customize the event of the keep-alive message.
446    ///
447    /// Default is an empty comment.
448    ///
449    /// # Panics
450    ///
451    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
452    /// comments.
453    pub fn event(mut self, event: Event) -> Self {
454        self.event = event.finalize();
455        self
456    }
457}
458
459impl Default for KeepAlive {
460    fn default() -> Self {
461        Self::new()
462    }
463}
464
465pin_project! {
466    #[derive(Debug)]
467    struct KeepAliveStream {
468        keep_alive: KeepAlive,
469        #[pin]
470        alive_timer: Sleep,
471    }
472}
473
474impl KeepAliveStream {
475    fn new(keep_alive: KeepAlive) -> Self {
476        Self {
477            alive_timer: tokio::time::sleep(keep_alive.max_interval),
478            keep_alive,
479        }
480    }
481
482    fn reset(self: Pin<&mut Self>) {
483        let this = self.project();
484        this.alive_timer
485            .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
486    }
487
488    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
489        let this = self.as_mut().project();
490
491        ready!(this.alive_timer.poll(cx));
492
493        let event = this.keep_alive.event.clone();
494
495        self.reset();
496
497        Poll::Ready(event)
498    }
499}
500
501fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
502    MemchrSplit {
503        needle,
504        haystack: Some(haystack),
505    }
506}
507
508struct MemchrSplit<'a> {
509    needle: u8,
510    haystack: Option<&'a [u8]>,
511}
512
513impl<'a> Iterator for MemchrSplit<'a> {
514    type Item = &'a [u8];
515    fn next(&mut self) -> Option<Self::Item> {
516        let haystack = self.haystack?;
517        if let Some(pos) = memchr::memchr(self.needle, haystack) {
518            let (front, back) = haystack.split_at(pos);
519            self.haystack = Some(&back[1..]);
520            Some(front)
521        } else {
522            self.haystack.take()
523        }
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn valid_json_raw_value_chars_stripped() {
        let json_string = "{\r\"foo\":  \n\r\r   \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        assert_eq!(
            &*json_raw_value_event.finalize(),
            format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
        );
    }

    #[test]
    fn multi_line_data_is_split_into_multiple_fields() {
        let event = Event::default().data("first\nsecond\n");
        assert_eq!(
            &*event.finalize(),
            b"data: first\ndata: second\ndata: \n\n"
        );
    }

    #[test]
    #[should_panic(expected = "SSE field value cannot contain newlines or carriage returns")]
    fn carriage_return_in_data_panics() {
        let _ = Event::default().data("a\r\nb");
    }

    #[test]
    #[should_panic(expected = "Called `EventBuilder::data` multiple times")]
    fn setting_data_twice_panics() {
        let _ = Event::default().data("one").data("two");
    }

    #[test]
    fn retry_is_serialized_in_whole_milliseconds() {
        let cases = [
            (Duration::ZERO, "0"),
            (Duration::from_millis(7), "7"),
            (Duration::from_millis(999), "999"),
            (Duration::from_secs(1), "1000"),
            (Duration::from_millis(1_005), "1005"),
            (Duration::from_millis(1_050), "1050"),
            (Duration::from_millis(12_345), "12345"),
            // Sub-millisecond precision is truncated.
            (Duration::from_micros(1_500), "1"),
        ];

        for (duration, expected) in cases {
            let event = Event::default().retry(duration);
            assert_eq!(
                &*event.finalize(),
                format!("retry:{expected}\n\n").as_bytes(),
                "{duration:?}"
            );
        }
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parse one serialized event into a `field name -> value` map, storing
    /// comment lines (empty field name) under the key `"comment"`.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            // The blank line terminates the event and must be the last line.
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}