Skip to content

Commit

Permalink
Add OAuthSession
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Nov 13, 2024
1 parent 6734492 commit d041ae7
Show file tree
Hide file tree
Showing 14 changed files with 229 additions and 40 deletions.
2 changes: 1 addition & 1 deletion atrium-api/src/agent/atp_agent/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ where
fn base_uri(&self) -> String {
self.store.get_endpoint()
}
async fn authentication_token(&self, is_refresh: bool) -> Option<String> {
async fn authorization_token(&self, is_refresh: bool) -> Option<String> {
self.store.get_session().await.map(|session| {
if is_refresh {
session.data.refresh_jwt
Expand Down
3 changes: 2 additions & 1 deletion atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ keywords = ["atproto", "bluesky", "oauth"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
atrium-api = { workspace = true, default-features = false }
atrium-api = { workspace = true, features = ["agent"] }
atrium-identity.workspace = true
atrium-xrpc.workspace = true
base64.workspace = true
Expand All @@ -34,6 +34,7 @@ thiserror.workspace = true
trait-variant.workspace = true

[dev-dependencies]
atrium-api = { workspace = true, features = ["bluesky"] }
hickory-resolver.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

Expand Down
24 changes: 21 additions & 3 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use atrium_api::agent::Agent;
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::state::MemoryStateStore;
Expand Down Expand Up @@ -61,7 +62,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.authorize(
std::env::var("HANDLE").unwrap_or(String::from("https://bsky.social")),
AuthorizeOptions {
scopes: Some(vec![String::from("atproto")]),
scopes: Some(vec![String::from("atproto"), String::from("transition:generic")]),
..Default::default()
}
)
Expand All @@ -78,7 +79,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let uri = url.trim().parse::<Uri>()?;
let params = serde_html_form::from_str(uri.query().unwrap())?;
println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?);

let (session, _) = client.callback(params).await?;
let agent = Agent::new(session);
println!(
"{:?}",
agent
.api
.app
.bsky
.feed
.get_timeline(
atrium_api::app::bsky::feed::get_timeline::ParametersData {
algorithm: None,
cursor: None,
limit: 1.try_into().ok()
}
.into()
)
.await?
);
Ok(())
}
2 changes: 1 addition & 1 deletion atrium-oauth/oauth-client/src/atproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata {
return Err(Error::EmptyRedirectUris);
}
Ok(OAuthClientMetadata {
client_id: String::from("http://localhost"),
client_id: String::from("http://localhost?scope=atproto+transition:generic"), // TODO
client_uri: None,
redirect_uris: self.redirect_uris,
scope: None, // will be set to `atproto`
Expand Down
6 changes: 4 additions & 2 deletions atrium-oauth/oauth-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ pub enum Error {
#[error(transparent)]
ClientMetadata(#[from] crate::atproto::Error),
#[error(transparent)]
Keyset(#[from] crate::keyset::Error),
Dpop(#[from] crate::http_client::dpop::Error),
#[error(transparent)]
Identity(#[from] atrium_identity::Error),
Keyset(#[from] crate::keyset::Error),
#[error(transparent)]
ServerAgent(#[from] crate::server_agent::Error),
#[error(transparent)]
Identity(#[from] atrium_identity::Error),
#[error("authorize error: {0}")]
Authorize(String),
#[error("callback error: {0}")]
Expand Down
64 changes: 46 additions & 18 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use jose_jwk::{crypto, EcCurves, Jwk, Key};
use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::sync::Arc;
use thiserror::Error;

Expand Down Expand Up @@ -44,13 +45,15 @@ where
#[allow(dead_code)]
iss: String,
nonces: S,
is_auth_server: bool,
}

impl<T> DpopClient<T> {
pub fn new(
key: Key,
iss: String,
http_client: Arc<T>,
is_auth_server: bool,
supported_algs: &Option<Vec<String>>,
) -> Result<Self> {
if let Some(algs) = supported_algs {
Expand All @@ -66,9 +69,21 @@ impl<T> DpopClient<T> {
}
}
let nonces = MemorySimpleStore::<String, String>::default();
Ok(Self { inner: http_client, key, iss, nonces })
Ok(Self { inner: http_client, key, iss, nonces, is_auth_server })
}
fn build_proof(&self, htm: String, htu: String, nonce: Option<String>) -> Result<String> {
}

impl<T, S> DpopClient<T, S>
where
S: SimpleStore<String, String>,
{
fn build_proof(
&self,
htm: String,
htu: String,
ath: Option<String>,
nonce: Option<String>,
) -> Result<String> {
match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? {
crypto::Key::P256(crypto::Kind::Secret(secret_key)) => {
let mut header = RegisteredHeader::from(Algorithm::Signing(Signing::Es256));
Expand All @@ -83,27 +98,32 @@ impl<T> DpopClient<T> {
iat: Some(Utc::now().timestamp()),
..Default::default()
},
public: PublicClaims {
htm: Some(htm),
htu: Some(htu),
nonce,
..Default::default()
},
public: PublicClaims { htm: Some(htm), htu: Some(htu), ath, nonce },
};
Ok(create_signed_jwt(secret_key.into(), header.into(), claims)?)
}
_ => unimplemented!(),
}
}
fn is_use_dpop_nonce_error(&self, response: &Response<Vec<u8>>) -> bool {
// is auth server?
if response.status() == 400 {
if let Ok(res) = serde_json::from_slice::<ErrorResponse>(response.body()) {
return res.error == "use_dpop_nonce";
};
// https://datatracker.ietf.org/doc/html/rfc9449#name-authorization-server-provid
if self.is_auth_server {
if response.status() == 400 {
if let Ok(res) = serde_json::from_slice::<ErrorResponse>(response.body()) {
return res.error == "use_dpop_nonce";
};
}
}
// https://datatracker.ietf.org/doc/html/rfc6750#section-3
// https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no
else if response.status() == 401 {
if let Some(www_auth) =
response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok())
{
return www_auth.starts_with("DPoP")
&& www_auth.contains(r#"error="use_dpop_nonce""#);
}
}
// is resource server?

false
}
// https://datatracker.ietf.org/doc/html/rfc9449#section-4.2
Expand All @@ -115,9 +135,10 @@ impl<T> DpopClient<T> {
}
}

impl<T> HttpClient for DpopClient<T>
impl<T, S> HttpClient for DpopClient<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: SimpleStore<String, String> + Send + Sync + 'static,
{
async fn send_http(
&self,
Expand All @@ -128,9 +149,16 @@ where
let nonce_key = uri.authority().unwrap().to_string();
let htm = request.method().to_string();
let htu = uri.to_string();
// https://datatracker.ietf.org/doc/html/rfc9449#section-4.2
let ath = request
.headers()
.get("Authorization")
.filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP ")))
.map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..])));

let init_nonce = self.nonces.get(&nonce_key).await?;
let init_proof = self.build_proof(htm.clone(), htu.clone(), init_nonce.clone())?;
let init_proof =
self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?;
request.headers_mut().insert("DPoP", init_proof.parse()?);
let response = self.inner.send_http(request.clone()).await?;

Expand All @@ -151,7 +179,7 @@ where
if !self.is_use_dpop_nonce_error(&response) {
return Ok(response);
}
let next_proof = self.build_proof(htm, htu, next_nonce)?;
let next_proof = self.build_proof(htm, htu, ath, next_nonce)?;
request.headers_mut().insert("DPoP", next_proof.parse()?);
let response = self.inner.send_http(request).await?;
Ok(response)
Expand Down
2 changes: 2 additions & 0 deletions atrium-oauth/oauth-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod http_client;
mod jose;
mod keyset;
mod oauth_client;
mod oauth_session;
mod resolver;
mod server_agent;
pub mod store;
Expand All @@ -19,6 +20,7 @@ pub use error::{Error, Result};
pub use http_client::default::DefaultHttpClient;
pub use http_client::dpop::DpopClient;
pub use oauth_client::{OAuthClient, OAuthClientConfig};
pub use oauth_session::OAuthSession;
pub use resolver::OAuthResolverConfig;
pub use types::{
AuthorizeOptionPrompt, AuthorizeOptions, CallbackParams, OAuthClientMetadata, TokenSet,
Expand Down
38 changes: 32 additions & 6 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::constants::FALLBACK_ALG;
use crate::error::{Error, Result};
use crate::http_client::dpop::{DpopClient, Error as DpopError};
use crate::keyset::Keyset;
use crate::oauth_session::OAuthSession;
use crate::resolver::{OAuthResolver, OAuthResolverConfig};
use crate::server_agent::{OAuthRequest, OAuthServerAgent};
use crate::store::state::{InternalStateData, StateStore};
Expand Down Expand Up @@ -155,6 +157,7 @@ where
iss: metadata.issuer.clone(),
dpop_key: dpop_key.clone(),
verifier,
app_state: options.state,
};
self.state_store
.set(state.clone(), state_data)
Expand Down Expand Up @@ -207,7 +210,10 @@ where
todo!()
}
}
pub async fn callback(&self, params: CallbackParams) -> Result<TokenSet> {
pub async fn callback(
&self,
params: CallbackParams,
) -> Result<(OAuthSession<T>, Option<String>)> {
let Some(state_key) = params.state else {
return Err(Error::Callback("missing `state` parameter".into()));
};
Expand Down Expand Up @@ -241,9 +247,15 @@ where
self.keyset.clone(),
)?;
let token_set = server.exchange_code(&params.code, &state.verifier).await?;
// TODO: store token_set to session store

// TODO: create session?
Ok(token_set)
let session = self.create_session(
state.dpop_key.clone(),
&metadata,
&self.client_metadata,
token_set,
)?;
Ok((session, state.app_state))
}
fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> {
let mut algs =
Expand All @@ -255,8 +267,22 @@ where
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
let verifier =
URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default()));
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
(URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())), verifier)
(URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
}
fn create_session(
&self,
dpop_key: Key,
server_metadata: &OAuthAuthorizationServerMetadata,
client_metadata: &OAuthClientMetadata,
token_set: TokenSet,
) -> core::result::Result<OAuthSession<T>, DpopError> {
let dpop_client = DpopClient::new(
dpop_key,
client_metadata.client_id.clone(),
self.http_client.clone(),
false,
&server_metadata.token_endpoint_auth_signing_alg_values_supported,
)?;
Ok(OAuthSession::new(dpop_client, token_set))
}
}
85 changes: 85 additions & 0 deletions atrium-oauth/oauth-client/src/oauth_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use crate::store::{memory::MemorySimpleStore, SimpleStore};
use crate::{DpopClient, TokenSet};
use atrium_api::{agent::SessionManager, types::string::Did};
use atrium_xrpc::types::AuthorizationType;
use atrium_xrpc::{
http::{Request, Response},
HttpClient, XrpcClient,
};

pub struct OAuthSession<T, S = MemorySimpleStore<String, String>>
where
S: SimpleStore<String, String>,
{
inner: DpopClient<T, S>,
token_set: TokenSet, // TODO: replace with a session store?
}

impl<T, S> OAuthSession<T, S>
where
S: SimpleStore<String, String> + Send + Sync + 'static,
{
pub fn new(dpop_client: DpopClient<T, S>, token_set: TokenSet) -> Self {
Self { inner: dpop_client, token_set }
}
}

impl<T, S> HttpClient for OAuthSession<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: SimpleStore<String, String> + Send + Sync + 'static,
{
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
self.inner.send_http(request).await
}
}

impl<T, S> XrpcClient for OAuthSession<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: SimpleStore<String, String> + Send + Sync + 'static,
{
fn base_uri(&self) -> String {
self.token_set.aud.clone()
}
fn authorization_type(&self) -> AuthorizationType {
AuthorizationType::Dpop
}
async fn authorization_token(&self, is_refresh: bool) -> Option<String> {

Check failure on line 51 in atrium-oauth/oauth-client/src/oauth_session.rs

View workflow job for this annotation

GitHub Actions / Rust (1.75.0)

unused variable: `is_refresh`

Check failure on line 51 in atrium-oauth/oauth-client/src/oauth_session.rs

View workflow job for this annotation

GitHub Actions / Rust (stable)

unused variable: `is_refresh`
Some(self.token_set.access_token.clone())
}
// async fn atproto_proxy_header(&self) -> Option<String> {
// todo!()
// }
// async fn atproto_accept_labelers_header(&self) -> Option<Vec<String>> {
// todo!()
// }
// async fn send_xrpc<P, I, O, E>(
// &self,
// request: &XrpcRequest<P, I>,
// ) -> Result<OutputDataOrBytes<O>, Error<E>>
// where
// P: Serialize + Send + Sync,
// I: Serialize + Send + Sync,
// O: DeserializeOwned + Send + Sync,
// E: DeserializeOwned + Send + Sync + Debug,
// {
// todo!()
// }
}

impl<T, S> SessionManager for OAuthSession<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: SimpleStore<String, String> + Send + Sync + 'static,
{
async fn did(&self) -> Option<Did> {
todo!()
}
}

#[cfg(test)]
mod tests {}
Loading

0 comments on commit d041ae7

Please sign in to comment.