Skip to content

Commit

Permalink
Modified ShaderWriter::pushScope to push a compound statement instead…
Browse files Browse the repository at this point in the history
… of a container.
  • Loading branch information
DragonJoker committed Jul 11, 2024
1 parent dd07d78 commit 48fe760
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 59 deletions.
2 changes: 1 addition & 1 deletion source/ShaderWriter/Writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ namespace sdw

void ShaderWriter::pushScope()
{
m_builder->pushScope( getStmtCache().makeContainer() );
m_builder->pushScope( getStmtCache().makeCompound() );
}

void ShaderWriter::popScope()
Expand Down
14 changes: 7 additions & 7 deletions test/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,22 +668,22 @@ namespace test
}
}

void reportFailure( char const * const error
, char const * const function
void reportFailure( std::string_view error
, std::string_view function
, int line
, TestCounts & testCounts )
{
testCounts.reportFailure( error, function, line );
testCounts.reportFailure( error.data(), function.data(), line );
}

void reportFailure( char const * const error
, char const * const callerFunction
void reportFailure( std::string_view error
, std::string_view callerFunction
, int callerLine
, char const * const calleeFunction
, std::string_view calleeFunction
, int calleeLine
, TestCounts & testCounts )
{
testCounts.reportFailure( error, callerFunction, callerLine, calleeFunction, calleeLine );
testCounts.reportFailure( error.data(), callerFunction.data(), callerLine, calleeFunction.data(), calleeLine );
}

//*********************************************************************************************
Expand Down
99 changes: 49 additions & 50 deletions test/Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,13 +538,40 @@ namespace test
class Exception
: public std::runtime_error
{
using std::runtime_error::runtime_error;

public:
explicit Exception( char const * const message
, char const * const function
, int32_t line )
: std::runtime_error{ message }
, m_function{ function }
, m_line{ line }
{
}

explicit Exception( std::string const & message
, char const * const function
, int32_t line )
: Exception{ message.c_str(), function, line }
{
}

std::string getText()const
{
return what();
}

std::string getFunction()const
{
return m_function;
}

int32_t getLine()const
{
return m_line;
}

std::string m_function;
int32_t m_line;
};

struct TestStringStreams
Expand Down Expand Up @@ -789,45 +816,17 @@ namespace test
void beginTest( TestCounts & testCounts
, std::string name );
void endTest( TestCounts & testCounts );
void reportFailure( char const * const error
, char const * const function
void reportFailure( std::string_view error
, std::string_view function
, int line
, TestCounts & testCounts );
void reportFailure( char const * const error
, char const * const callerFunction
void reportFailure( std::string_view error
, std::string_view callerFunction
, int callerLine
, char const * const calleeFunction
, std::string_view calleeFunction
, int calleeLine
, TestCounts & testCounts );

inline void reportFailure( std::string_view error
, char const * const function
, int line
, TestCounts & testCounts )
{
reportFailure( error.data(), function, line, testCounts );
}

inline void reportFailure( std::string_view error
, char const * const callerFunction
, int callerLine
, char const * const calleeFunction
, int calleeLine
, TestCounts & testCounts )
{
reportFailure( error.data(), callerFunction, callerLine, calleeFunction, calleeLine, testCounts );
}

inline void reportFailure( std::string_view error
, std::string const & callerFunction
, int callerLine
, char const * const calleeFunction
, int calleeLine
, TestCounts & testCounts )
{
reportFailure( error.data(), callerFunction.c_str(), callerLine, calleeFunction, calleeLine, testCounts );
}

# define astTestSuiteMain( testName )\
static test::TestResults launch##testName( test::TestSuite & suite, test::TestCounts & testCounts )

Expand Down Expand Up @@ -945,13 +944,13 @@ namespace test
testCounts.incTest();\
if ( !( x ) )\
{\
throw test::Exception{ "\n Value: " + toString( x ) };\
throw test::Exception{ testConcatStr2( x, " failed." ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( testConcatStr2( x, " failed:" ) + exc.getText(), __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( testConcatStr2( x, " failed:" ) + exc.getText(), exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand All @@ -964,19 +963,19 @@ namespace test
testCounts.incTest();\
if ( !( x ) )\
{\
throw test::Exception{ "\n Value: " + toString( x ) };\
throw test::Exception{ testConcatStr2( x, " failed." ), __FUNCTION__, __LINE__ };\
}

#define astEndRequire\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( testConcatStr2( x, " failed:" ) + exc.getText(), __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( exc.getText(), exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
test::reportFailure( testConcatStr2( x, " failed." ), __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( "Unknown unhandled exception.", __FUNCTION__, __LINE__, testCounts );\
}

#define astCheck( x )\
Expand All @@ -1000,13 +999,13 @@ namespace test
testCounts.incTest();\
if ( !( ( x ) == ( y ) ) )\
{\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ) };\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( astTestConcatStr4( x, " == ", y, " failed:" ) + exc.getText(), __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( astTestConcatStr4( x, " == ", y, " failed:" ) + exc.getText(), exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand All @@ -1019,13 +1018,13 @@ namespace test
testCounts.incTest();\
if ( ( x ) == ( y ) )\
{\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ) };\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( astTestConcatStr4( x, " != ", y, " failed:" ) + exc.getText(), __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( astTestConcatStr4( x, " != ", y, " failed:" ) + exc.getText(), exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand Down Expand Up @@ -1077,13 +1076,13 @@ namespace test
testCounts.incTest();\
if ( !( x ) )\
{\
throw test::Exception{ "\n Value: " + toString( x ) };\
throw test::Exception{ "\n Value: " + toString( x ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( testConcatStr2( x, " failed:" ) + exc.getText(), f, l, __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( testConcatStr2( x, " failed:" ) + exc.getText(), f, l, exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand Down Expand Up @@ -1111,13 +1110,13 @@ namespace test
testCounts.incTest();\
if ( !( ( x ) == ( y ) ) )\
{\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ) };\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( astTestConcatStr4( x, " == ", y, " failed:" ) + exc.getText(), f, l, __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( astTestConcatStr4( x, " == ", y, " failed:" ) + exc.getText(), f, l, exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand All @@ -1130,13 +1129,13 @@ namespace test
testCounts.incTest();\
if ( ( x ) == ( y ) )\
{\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ) };\
throw test::Exception{ "\n LHS: " + toString( x ) + "\n RHS: " + toString( y ), __FUNCTION__, __LINE__ };\
}\
testCounts.flushErrors();\
}\
catch ( test::Exception & exc )\
{\
test::reportFailure( astTestConcatStr4( x, " != ", y, " failed:" ) + exc.getText(), f, l, __FUNCTION__, __LINE__, testCounts );\
test::reportFailure( astTestConcatStr4( x, " != ", y, " failed:" ) + exc.getText(), f, l, exc.getFunction(), exc.getLine(), testCounts );\
}\
catch ( ... )\
{\
Expand Down
66 changes: 66 additions & 0 deletions test/ShaderWriter/TestWriterControlStatements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,70 @@ namespace
, testCounts, CurrentCompilers );
astTestEnd();
}

void testAnonymousScope( test::sdw_test::TestCounts & testCounts )
{
astTestBegin( "testAnonymousScope" );
sdw::ShaderArray shaders;
{
sdw::ComputeWriter writer{ &testCounts.allocator };
auto i = writer.declSharedVariable< sdw::UInt >( "i" );
writer.implementMain( 32u, [&]( sdw::ComputeIn in )
{
if (auto scope = makeScope( writer ) )
{
i = in.globalInvocationID.x();
}
} );
test::writeShader( writer
, testCounts, CurrentCompilers );
shaders.emplace_back( std::move( writer.getShader() ) );
}
test::validateShaders( shaders
, testCounts, CurrentCompilers );
astTestEnd();
}

void testNestedAnonymousScopes( test::sdw_test::TestCounts & testCounts )
{
astTestBegin( "testNestedAnonymousScopes" );
sdw::ShaderArray shaders;
{
sdw::ComputeWriter writer{ &testCounts.allocator };
auto i = writer.declSharedVariable< sdw::UInt >( "i" );
auto j = writer.declSharedVariable< sdw::UInt >( "j" );
auto k = writer.declSharedVariable< sdw::UInt >( "k" );
auto l = writer.declSharedVariable< sdw::UInt >( "l" );
writer.implementMain( 32u, [&]( sdw::ComputeIn in )
{
if (auto scope1 = makeScope( writer ) )
{
l = in.localInvocationIndex;
if ( auto scope2 = makeScope( writer ) )
{
k = in.globalInvocationID.z();
if ( auto scope3 = makeScope( writer ) )
{
j = in.globalInvocationID.y();
if ( auto scope4 = makeScope( writer ) )
{
i = in.globalInvocationID.x();
}
j += in.globalInvocationID.y();
}
k += in.globalInvocationID.z();
}
l += in.localInvocationIndex;
}
} );
test::writeShader( writer
, testCounts, CurrentCompilers );
shaders.emplace_back( std::move( writer.getShader() ) );
}
test::validateShaders( shaders
, testCounts, CurrentCompilers );
astTestEnd();
}
}

sdwTestSuiteMain( TestWriterControlStatements )
Expand Down Expand Up @@ -884,6 +948,8 @@ sdwTestSuiteMain( TestWriterControlStatements )
testConstSwitch0( testCounts );
testConstSwitch1( testCounts );
testConstSwitchDefault( testCounts );
testAnonymousScope( testCounts );
testNestedAnonymousScopes( testCounts );
sdwTestSuiteEnd();
}

Expand Down
2 changes: 1 addition & 1 deletion test/ShaderWriter/TestWriterIncrement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ namespace
}
ROF;
astCheckEqual( writer.getBuilder().getContainer()->size(), 1u );
astBeginRequire( writer.getBuilder().getContainer()->back()->getKind() == stmt::Kind::eContainer );
astBeginRequire( writer.getBuilder().getContainer()->back()->getKind() == stmt::Kind::eCompound );
astCheckEqual( static_cast< stmt::Container const & >( *writer.getBuilder().getContainer()->back() ).size(), 1u );
astBeginRequire( static_cast< stmt::Container const & >( *writer.getBuilder().getContainer()->back() ).back()->getKind() == stmt::Kind::eFor );
astCheckEqual( static_cast< stmt::For const & >( *static_cast< stmt::Container const & >( *writer.getBuilder().getContainer()->back() ).back() ).size(), 1u );
Expand Down

0 comments on commit 48fe760

Please sign in to comment.