Skip to content

Commit

Permalink
feat: validate variable data locations (#149)
Browse files Browse the repository at this point in the history
* feat: add VarKind

* feat: validate variable data locations

* tests

* note

* tests
  • Loading branch information
DaniPopes authored Nov 28, 2024
1 parent 3a9c5c5 commit 57071d7
Show file tree
Hide file tree
Showing 28 changed files with 897 additions and 214 deletions.
8 changes: 8 additions & 0 deletions crates/ast/src/ast/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,14 @@ impl DataLocation {
Self::Calldata => "calldata",
}
}

/// Returns the string representation of the storage location, or `"none"` if `None`.
pub const fn opt_to_str(this: Option<Self>) -> &'static str {
match this {
Some(location) => location.to_str(),
None => "none",
}
}
}

// How a function can mutate the EVM state.
Expand Down
6 changes: 6 additions & 0 deletions crates/ast/src/ast/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ impl ElementaryType {
| Self::FixedBytes(..)
)
}

/// Returns `true` if the type is a reference type.
#[inline]
pub const fn is_reference_type(self) -> bool {
matches!(self, Self::String | Self::Bytes)
}
}

/// Byte size of a fixed-bytes, integer, or fixed-point number (M) type. Valid values: 0..=32.
Expand Down
82 changes: 82 additions & 0 deletions crates/data-structures/src/fmt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::{cell::Cell, fmt};

/// Wrapper for [`fmt::from_fn`].
#[cfg(feature = "nightly")]
pub fn from_fn<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(
f: F,
) -> impl fmt::Debug + fmt::Display {
fmt::from_fn(f)
}

/// Polyfill for [`fmt::from_fn`].
#[cfg(not(feature = "nightly"))]
pub fn from_fn<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(
f: F,
) -> impl fmt::Debug + fmt::Display {
struct FromFn<F>(F);

impl<F> fmt::Debug for FromFn<F>
where
F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(self.0)(f)
}
}

impl<F> fmt::Display for FromFn<F>
where
F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(self.0)(f)
}
}

FromFn(f)
}

/// Returns `list` formatted as a comma-separated list with "or" before the last item.
pub fn or_list<I>(list: I) -> impl fmt::Display
where
I: IntoIterator<IntoIter: ExactSizeIterator, Item: fmt::Display>,
{
let list = Cell::new(Some(list.into_iter()));
from_fn(move |f| {
let list = list.take().expect("or_list called twice");
let len = list.len();
for (i, t) in list.enumerate() {
if i > 0 {
let is_last = i == len - 1;
f.write_str(if len > 2 && is_last {
", or "
} else if len == 2 && is_last {
" or "
} else {
", "
})?;
}
write!(f, "{t}")?;
}
Ok(())
})
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_or_list() {
let tests: &[(&[&str], &str)] = &[
(&[], ""),
(&["`<eof>`"], "`<eof>`"),
(&["integer", "identifier"], "integer or identifier"),
(&["path", "string literal", "`&&`"], "path, string literal, or `&&`"),
(&["`&&`", "`||`", "`&&`", "`||`"], "`&&`, `||`, `&&`, or `||`"),
];
for &(tokens, expected) in tests {
assert_eq!(or_list(tokens).to_string(), expected, "{tokens:?}");
}
}
}
39 changes: 1 addition & 38 deletions crates/data-structures/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
#![cfg_attr(feature = "nightly", feature(rustc_attrs))]
#![cfg_attr(feature = "nightly", allow(internal_features))]

use std::fmt;

pub mod cycle;
pub mod fmt;
pub mod hint;
pub mod index;
pub mod map;
Expand Down Expand Up @@ -42,39 +41,3 @@ pub use smallvec;
pub fn outline<R>(f: impl FnOnce() -> R) -> R {
f()
}

/// Wrapper for [`fmt::from_fn`].
#[cfg(feature = "nightly")]
pub fn fmt_from_fn<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(
f: F,
) -> impl fmt::Debug + fmt::Display {
fmt::from_fn(f)
}

/// Polyfill for [`fmt::from_fn`];
#[cfg(not(feature = "nightly"))]
pub fn fmt_from_fn<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(
f: F,
) -> impl fmt::Debug + fmt::Display {
struct FromFn<F>(F);

impl<F> fmt::Debug for FromFn<F>
where
F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(self.0)(f)
}
}

impl<F> fmt::Display for FromFn<F>
where
F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(self.0)(f)
}
}

FromFn(f)
}
10 changes: 2 additions & 8 deletions crates/parse/src/parser/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,12 +738,7 @@ impl<'sess, 'ast> Parser<'sess, 'ast> {
let mut indexed = false;
loop {
if let Some(s) = self.parse_data_location() {
let transient = matches!(s, DataLocation::Transient);
let transient_allowed = flags.contains(VarFlags::TRANSIENT);
if transient && !transient_allowed {
let msg = "`transient` data location is not allowed here";
self.dcx().err(msg).span(self.prev_token.span).emit();
} else if !(transient || flags.contains(VarFlags::DATALOC)) {
if !flags.contains(VarFlags::DATALOC) {
let msg = "data locations are not allowed here";
self.dcx().err(msg).span(self.prev_token.span).emit();
} else if data_location.is_some() {
Expand Down Expand Up @@ -984,7 +979,6 @@ bitflags::bitflags! {
pub(super) struct VarFlags: u16 {
// `ty` is always required. `name` is always optional, unless `NAME` is specified.

const TRANSIENT = 1 << 0;
const DATALOC = 1 << 1;
const INDEXED = 1 << 2;

Expand Down Expand Up @@ -1015,7 +1009,7 @@ bitflags::bitflags! {
const FUNCTION_TY = Self::DATALOC.bits() | Self::NAME_WARN.bits();

// https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityParser.stateVariableDeclaration
const STATE_VAR = Self::TRANSIENT.bits()
const STATE_VAR = Self::DATALOC.bits()
| Self::PRIVATE.bits()
| Self::INTERNAL.bits()
| Self::PUBLIC.bits()
Expand Down
47 changes: 3 additions & 44 deletions crates/parse/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@ use solar_ast::{
token::{Delimiter, Token, TokenKind},
AstPath, Box, DocComment, DocComments, PathSlice,
};
use solar_data_structures::BumpExt;
use solar_data_structures::{fmt::or_list, BumpExt};
use solar_interface::{
diagnostics::DiagCtxt,
source_map::{FileName, SourceFile},
Ident, Result, Session, Span, Symbol,
};
use std::{
fmt::{self, Write},
path::Path,
};
use std::{fmt, path::Path};

mod expr;
mod item;
Expand Down Expand Up @@ -79,7 +76,7 @@ impl fmt::Display for ExpectedToken {

impl ExpectedToken {
fn to_string_many(tokens: &[Self]) -> String {
or_list(tokens)
or_list(tokens).to_string()
}

fn eq_kind(&self, other: &TokenKind) -> bool {
Expand Down Expand Up @@ -978,41 +975,3 @@ impl<'sess, 'ast> Parser<'sess, 'ast> {
self.expected_ident_found(false).unwrap_err()
}
}

fn or_list<T: fmt::Display>(list: &[T]) -> String {
let len = list.len();
let mut s = String::with_capacity(16 * len);
for (i, t) in list.iter().enumerate() {
if i > 0 {
let is_last = i == len - 1;
s.push_str(if len > 2 && is_last {
", or "
} else if len == 2 && is_last {
" or "
} else {
", "
});
}
let _ = write!(s, "{t}");
}
s
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_or_list() {
let tests: &[(&[&str], &str)] = &[
(&[], ""),
(&["`<eof>`"], "`<eof>`"),
(&["integer", "identifier"], "integer or identifier"),
(&["path", "string literal", "`&&`"], "path, string literal, or `&&`"),
(&["`&&`", "`||`", "`&&`", "`||`"], "`&&`, `||`, `&&`, or `||`"),
];
for &(tokens, expected) in tests {
assert_eq!(or_list(tokens), expected, "{tokens:?}");
}
}
}
29 changes: 22 additions & 7 deletions crates/sema/src/ast_lowering/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ impl<'ast> super::LoweringContext<'_, 'ast, '_> {
}
ast::ItemKind::Contract(i) => hir::ItemId::Contract(self.lower_contract(item, i)),
ast::ItemKind::Function(i) => hir::ItemId::Function(self.lower_function(item, i)),
ast::ItemKind::Variable(i) => hir::ItemId::Variable(self.lower_variable(i)),
ast::ItemKind::Variable(i) => {
let kind = if self.current_contract_id.is_some() {
hir::VarKind::State
} else {
hir::VarKind::Global
};
hir::ItemId::Variable(self.lower_variable(i, kind))
}
ast::ItemKind::Struct(i) => hir::ItemId::Struct(self.lower_struct(item, i)),
ast::ItemKind::Enum(i) => hir::ItemId::Enum(self.lower_enum(item, i)),
ast::ItemKind::Udvt(i) => hir::ItemId::Udvt(self.lower_udvt(item, i)),
Expand Down Expand Up @@ -154,13 +161,18 @@ impl<'ast> super::LoweringContext<'_, 'ast, '_> {
})
}

fn lower_variable(&mut self, i: &ast::VariableDefinition<'_>) -> hir::VariableId {
fn lower_variable(
&mut self,
i: &ast::VariableDefinition<'_>,
kind: hir::VarKind,
) -> hir::VariableId {
lower_variable_partial(
&mut self.hir,
i,
self.current_source_id,
self.current_contract_id,
self.current_contract_id.is_some(),
None,
kind,
)
}

Expand Down Expand Up @@ -231,7 +243,8 @@ pub(super) fn lower_variable_partial(
i: &ast::VariableDefinition<'_>,
source: SourceId,
contract: Option<ContractId>,
is_state_variable: bool,
function: Option<hir::FunctionId>,
kind: hir::VarKind,
) -> hir::VariableId {
// handled later: ty, override_, initializer
let ast::VariableDefinition {
Expand All @@ -248,7 +261,9 @@ pub(super) fn lower_variable_partial(
let id = hir.variables.push(hir::Variable {
source,
contract,
function,
span,
kind,
ty: hir::Type::DUMMY,
name,
visibility,
Expand All @@ -258,7 +273,6 @@ pub(super) fn lower_variable_partial(
overrides: &[],
indexed,
initializer: None,
is_state_variable,
getter: None,
});
let v = hir.variable(id);
Expand All @@ -272,7 +286,9 @@ fn generate_partial_getter(hir: &mut hir::Hir<'_>, id: hir::VariableId) -> hir::
let hir::Variable {
source,
contract,
function: _,
span,
kind,
ty: _,
name,
visibility,
Expand All @@ -282,13 +298,12 @@ fn generate_partial_getter(hir: &mut hir::Hir<'_>, id: hir::VariableId) -> hir::
overrides,
indexed,
initializer: _,
is_state_variable,
getter,
} = *hir.variable(id);
debug_assert!(!indexed);
debug_assert!(data_location.is_none());
debug_assert_eq!(visibility, Some(ast::Visibility::Public));
debug_assert!(is_state_variable);
debug_assert!(kind.is_state());
debug_assert!(getter.is_none());
hir.functions.push(hir::Function {
source,
Expand Down
Loading

0 comments on commit 57071d7

Please sign in to comment.