hyper_util/client/legacy/connect/proxy/
tunnel.rs1use std::error::Error as StdError;
2use std::future::Future;
3use std::marker::{PhantomData, Unpin};
4use std::pin::Pin;
5use std::task::{self, Poll};
6
7use futures_core::ready;
8use http::{HeaderMap, HeaderValue, Uri};
9use hyper::rt::{Read, Write};
10use pin_project_lite::pin_project;
11use tower_service::Service;
12
13#[derive(Debug)]
19pub struct Tunnel<C> {
20 headers: Headers,
21 inner: C,
22 proxy_dst: Uri,
23}
24
25#[derive(Clone, Debug)]
26enum Headers {
27 Empty,
28 Auth(HeaderValue),
29 Extra(HeaderMap),
30}
31
/// Failures that can occur while setting up a tunnel through a proxy.
#[derive(Debug)]
pub enum TunnelError {
    /// The wrapped connector failed, either while becoming ready or while connecting
    /// to the proxy.
    ConnectFailed(Box<dyn StdError + Send + Sync>),
    /// An I/O error occurred while writing the `CONNECT` request or reading the reply.
    Io(std::io::Error),
    /// The destination URI has no host to place in the `CONNECT` request.
    MissingHost,
    /// The proxy replied with status `407` (Proxy Authentication Required).
    ProxyAuthRequired,
    /// The proxy's successful reply head did not fit in the read buffer.
    ProxyHeadersTooLong,
    /// The proxy closed the connection before its reply head was complete.
    TunnelUnexpectedEof,
    /// The proxy refused the tunnel or sent a reply that was not understood.
    TunnelUnsuccessful,
}
42
43pin_project! {
44 #[must_use = "futures do nothing unless polled"]
50 #[allow(missing_debug_implementations)]
51 pub struct Tunneling<F, T> {
52 #[pin]
53 fut: BoxTunneling<T>,
54 _marker: PhantomData<F>,
55 }
56}
57
58type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>;
59
60impl<C> Tunnel<C> {
61 pub fn new(proxy_dst: Uri, connector: C) -> Self {
70 Self {
71 headers: Headers::Empty,
72 inner: connector,
73 proxy_dst,
74 }
75 }
76
77 pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
79 auth.set_sensitive(true);
81 match self.headers {
82 Headers::Empty => {
83 self.headers = Headers::Auth(auth);
84 }
85 Headers::Auth(ref mut existing) => {
86 *existing = auth;
87 }
88 Headers::Extra(ref mut extra) => {
89 extra.insert(http::header::PROXY_AUTHORIZATION, auth);
90 }
91 }
92
93 self
94 }
95
96 pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
100 match self.headers {
101 Headers::Empty => {
102 self.headers = Headers::Extra(headers);
103 }
104 Headers::Auth(auth) => {
105 headers
106 .entry(http::header::PROXY_AUTHORIZATION)
107 .or_insert(auth);
108 self.headers = Headers::Extra(headers);
109 }
110 Headers::Extra(ref mut extra) => {
111 extra.extend(headers);
112 }
113 }
114
115 self
116 }
117}
118
119impl<C> Service<Uri> for Tunnel<C>
120where
121 C: Service<Uri>,
122 C::Future: Send + 'static,
123 C::Response: Read + Write + Unpin + Send + 'static,
124 C::Error: Into<Box<dyn StdError + Send + Sync>>,
125{
126 type Response = C::Response;
127 type Error = TunnelError;
128 type Future = Tunneling<C::Future, C::Response>;
129
130 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
131 ready!(self.inner.poll_ready(cx)).map_err(|e| TunnelError::ConnectFailed(e.into()))?;
132 Poll::Ready(Ok(()))
133 }
134
135 fn call(&mut self, dst: Uri) -> Self::Future {
136 let connecting = self.inner.call(self.proxy_dst.clone());
137 let headers = self.headers.clone();
138
139 Tunneling {
140 fut: Box::pin(async move {
141 let conn = connecting
142 .await
143 .map_err(|e| TunnelError::ConnectFailed(e.into()))?;
144 tunnel(
145 conn,
146 dst.host().ok_or(TunnelError::MissingHost)?,
147 dst.port().map(|p| p.as_u16()).unwrap_or(443),
148 &headers,
149 )
150 .await
151 }),
152 _marker: PhantomData,
153 }
154 }
155}
156
157impl<F, T, E> Future for Tunneling<F, T>
158where
159 F: Future<Output = Result<T, E>>,
160{
161 type Output = Result<T, TunnelError>;
162
163 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
164 self.project().fut.poll(cx)
165 }
166}
167
168async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError>
169where
170 T: Read + Write + Unpin,
171{
172 let mut buf = format!(
173 "\
174 CONNECT {host}:{port} HTTP/1.1\r\n\
175 Host: {host}:{port}\r\n\
176 "
177 )
178 .into_bytes();
179
180 match headers {
181 Headers::Auth(auth) => {
182 buf.extend_from_slice(b"Proxy-Authorization: ");
183 buf.extend_from_slice(auth.as_bytes());
184 buf.extend_from_slice(b"\r\n");
185 }
186 Headers::Extra(extra) => {
187 for (name, value) in extra {
188 buf.extend_from_slice(name.as_str().as_bytes());
189 buf.extend_from_slice(b": ");
190 buf.extend_from_slice(value.as_bytes());
191 buf.extend_from_slice(b"\r\n");
192 }
193 }
194 Headers::Empty => (),
195 }
196
197 buf.extend_from_slice(b"\r\n");
199
200 crate::rt::write_all(&mut conn, &buf)
201 .await
202 .map_err(TunnelError::Io)?;
203
204 let mut buf = [0; 8192];
205 let mut pos = 0;
206
207 loop {
208 let n = crate::rt::read(&mut conn, &mut buf[pos..])
209 .await
210 .map_err(TunnelError::Io)?;
211
212 if n == 0 {
213 return Err(TunnelError::TunnelUnexpectedEof);
214 }
215 pos += n;
216
217 let recvd = &buf[..pos];
218 if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
219 if recvd.ends_with(b"\r\n\r\n") {
220 return Ok(conn);
221 }
222 if pos == buf.len() {
223 return Err(TunnelError::ProxyHeadersTooLong);
224 }
225 } else if recvd.starts_with(b"HTTP/1.1 407") {
227 return Err(TunnelError::ProxyAuthRequired);
228 } else {
229 return Err(TunnelError::TunnelUnsuccessful);
230 }
231 }
232}
233
234impl std::fmt::Display for TunnelError {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.write_str("tunnel error: ")?;
237
238 f.write_str(match self {
239 TunnelError::MissingHost => "missing destination host",
240 TunnelError::ProxyAuthRequired => "proxy authorization required",
241 TunnelError::ProxyHeadersTooLong => "proxy response headers too long",
242 TunnelError::TunnelUnexpectedEof => "unexpected end of file",
243 TunnelError::TunnelUnsuccessful => "unsuccessful",
244 TunnelError::ConnectFailed(_) => "failed to create underlying connection",
245 TunnelError::Io(_) => "io error establishing tunnel",
246 })
247 }
248}
249
250impl std::error::Error for TunnelError {
251 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
252 match self {
253 TunnelError::Io(ref e) => Some(e),
254 TunnelError::ConnectFailed(ref e) => Some(&**e),
255 _ => None,
256 }
257 }
258}