Skip to content

Commit

Permalink
Handle Class Constructor with arguments using __init__ method (#2775)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanay-man authored Jul 16, 2024
1 parent f4c4b94 commit 9e8da7d
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 61 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit)

RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)

# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
44 changes: 44 additions & 0 deletions integration_tests/class_02.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from lpython import i32
class Character:
def __init__(self:"Character", name:str, health:i32, attack_power:i32):
self.name :str = name
self.health :i32 = health
self.attack_power : i32 = attack_power
self.is_immortal : bool = False

def attack(self:"Character", other:"Character") -> str:
other.health -= self.attack_power
return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage."

def is_alive(self:"Character")->bool:
if self.is_immortal:
return True
else:
return self.health > 0

def main():
hero : Character = Character("Hero", 10, 20)
monster : Character = Character("Monster", 50, 15)
print(hero.attack(monster))
print(monster.health)
assert monster.health == 30
print(monster.is_alive())
assert monster.is_alive() == True
print("Hero gains temporary immortality")
hero.is_immortal = True
print(monster.attack(hero))
print(hero.health)
assert hero. health == -5
print(hero.is_alive())
assert hero.is_alive() == True
print("Hero's immortality runs out")
hero.is_immortal = False
print(hero.is_alive())
assert hero.is_alive() == False
print("Restarting")
hero = Character("Hero", 10, 20)
print(hero.is_alive())
assert hero.is_alive() == True

main()

30 changes: 30 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void instantiate_methods(const ASR::Struct_t &x) {
SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
for ( auto &item : x.m_symtab->get_scope() ) {
if ( is_a<ASR::Function_t>(*item.second) ) {
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
instantiate_function(*v);
}
}
current_scope = current_scope_copy;
}

void visit_methods (const ASR::Struct_t &x) {
SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
for ( auto &item : x.m_symtab->get_scope() ) {
if ( is_a<ASR::Function_t>(*item.second) ) {
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
visit_Function(*v);
}
}
current_scope = current_scope_copy;
}

void start_module_init_function_prototype(const ASR::Module_t &x) {
uint32_t h = get_hash((ASR::asr_t*)&x);
llvm::FunctionType *function_type = llvm::FunctionType::get(
Expand Down Expand Up @@ -3128,6 +3152,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else if (is_a<ASR::EnumType_t>(*item.second)) {
ASR::EnumType_t *et = down_cast<ASR::EnumType_t>(item.second);
visit_EnumType(*et);
} else if (is_a<ASR::Struct_t>(*item.second)) {
ASR::Struct_t *st = down_cast<ASR::Struct_t>(item.second);
instantiate_methods(*st);
}
}
finish_module_init_function_prototype(x);
Expand Down Expand Up @@ -4179,6 +4206,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
visit_Function(*s);
} else if ( is_a<ASR::Struct_t>(*item.second) ) {
ASR::Struct_t *st = down_cast<ASR::Struct_t>(item.second);
visit_methods(*st);
}
}
}
Expand Down
Loading

0 comments on commit 9e8da7d

Please sign in to comment.