Skip to content

Commit

Permalink
Merge pull request #41679 from gabilang/fix-default-worker-error-return
Browse files Browse the repository at this point in the history
Fix default worker crash with error return
  • Loading branch information
warunalakshitha authored Nov 23, 2023
2 parents ad0ab92 + 3b4940b commit 79694aa
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ void genPanic(BIRTerminator.Panic panicTerm) {
}

public void generateTryCatch(BIRNode.BIRFunction func, String funcName, BIRNode.BIRBasicBlock currentBB,
JvmTerminatorGen termGen, LabelGenerator labelGen, int invocationVarIndex) {
JvmTerminatorGen termGen, LabelGenerator labelGen, int invocationVarIndex,
int localVarOffset) {

BIRNode.BIRErrorEntry currentEE = findErrorEntry(func.errorTable, currentBB);
if (currentEE == null) {
Expand Down Expand Up @@ -108,7 +109,7 @@ public void generateTryCatch(BIRNode.BIRFunction func, String funcName, BIRNode.
this.mv.visitMethodInsn(INVOKESTATIC, ERROR_UTILS, CREATE_INTEROP_ERROR_METHOD,
CREATE_ERROR_FROM_THROWABLE, false);
jvmInstructionGen.generateVarStore(this.mv, retVarDcl, retIndex);
termGen.genReturnTerm(retIndex, func, invocationVarIndex);
termGen.genReturnTerm(retIndex, func, invocationVarIndex, localVarOffset);
this.mv.visitJumpInsn(GOTO, jumpLabel);
}
if (!exeptionExist) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ public void genTerminator(BIRTerminator terminator, String moduleClassName, BIRN
this.genBranchTerm((BIRTerminator.Branch) terminator, funcName);
return;
case RETURN:
this.genReturnTerm(returnVarRefIndex, func, invocationVarIndex);
this.genReturnTerm(returnVarRefIndex, func, invocationVarIndex, localVarOffset);
return;
case PANIC:
this.errorGen.genPanic((BIRTerminator.Panic) terminator);
Expand Down Expand Up @@ -448,9 +448,9 @@ private void genUnlockTerm(BIRTerminator.Unlock unlockIns, String funcName) {
}

private void handleErrorRetInUnion(int returnVarRefIndex, List<BIRNode.ChannelDetails> channels, BUnionType bType,
int invocationVarIndex) {
int invocationVarIndex, int localVarOffset) {

if (channels.size() == 0) {
if (channels.isEmpty()) {
return;
}

Expand All @@ -465,7 +465,7 @@ private void handleErrorRetInUnion(int returnVarRefIndex, List<BIRNode.ChannelDe

if (errorIncluded) {
this.mv.visitVarInsn(ALOAD, returnVarRefIndex);
this.mv.visitVarInsn(ALOAD, 0);
this.mv.visitVarInsn(ALOAD, localVarOffset);
JvmCodeGenUtil.loadChannelDetails(this.mv, channels, invocationVarIndex);
this.mv.visitMethodInsn(INVOKESTATIC, WORKER_UTILS, "handleWorkerError",
HANDLE_WORKER_ERROR, false);
Expand All @@ -474,7 +474,7 @@ private void handleErrorRetInUnion(int returnVarRefIndex, List<BIRNode.ChannelDe

private void notifyChannels(List<BIRNode.ChannelDetails> channels, int retIndex, int invocationVarIndex) {

if (channels.size() == 0) {
if (channels.isEmpty()) {
return;
}

Expand Down Expand Up @@ -1324,13 +1324,14 @@ private void genResourcePathArgs(List<BIROperand> pathArgs) {
mv.visitVarInsn(ALOAD, bundledArrayIndex);
}

public void genReturnTerm(int returnVarRefIndex, BIRNode.BIRFunction func, int invocationVarIndex) {
public void genReturnTerm(int returnVarRefIndex, BIRNode.BIRFunction func, int invocationVarIndex,
int localVarOffset) {
BType bType = unifier.build(func.type.retType);
generateReturnTermFromType(returnVarRefIndex, bType, func, invocationVarIndex);
generateReturnTermFromType(returnVarRefIndex, bType, func, invocationVarIndex, localVarOffset);
}

private void generateReturnTermFromType(int returnVarRefIndex, BType bType, BIRNode.BIRFunction func,
int invocationVarIndex) {
int invocationVarIndex, int localVarOffset) {
bType = JvmCodeGenUtil.getImpliedType(bType);
if (TypeTags.isIntegerTypeTag(bType.tag)) {
this.mv.visitVarInsn(LLOAD, returnVarRefIndex);
Expand Down Expand Up @@ -1377,7 +1378,7 @@ private void generateReturnTermFromType(int returnVarRefIndex, BType bType, BIRN
break;
case TypeTags.UNION:
this.handleErrorRetInUnion(returnVarRefIndex, Arrays.asList(func.workerChannels),
(BUnionType) bType, invocationVarIndex);
(BUnionType) bType, invocationVarIndex, localVarOffset);
this.mv.visitVarInsn(ALOAD, returnVarRefIndex);
this.mv.visitInsn(ARETURN);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ static void genJFieldForInteropField(JFieldBIRFunction birFunc, ClassWriter clas
Label retLabel = labelGen.getLabel("return_lable");
mv.visitLabel(retLabel);
mv.visitLineNumber(birFunc.pos.lineRange().endLine().line() + 1, retLabel);
termGen.genReturnTerm(returnVarRefIndex, birFunc, -1);
termGen.genReturnTerm(returnVarRefIndex, birFunc, -1, 0);
JvmCodeGenUtil.visitMaxStackForMethod(mv, birFunc.name.value, birFunc.javaField.getDeclaringClassName());
mv.visitEnd();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public void genJMethodForBFunc(BIRFunction func, ClassWriter cw, BIRPackage modu

Label methodEndLabel = new Label();
mv.visitLabel(methodEndLabel);
termGen.genReturnTerm(returnVarRefIndex, func, invocationVarIndex);
termGen.genReturnTerm(returnVarRefIndex, func, invocationVarIndex, localVarOffset);

// Create Local Variable Table
createLocalVariableTable(func, indexMap, localVarOffset, mv, methodStartLabel, labelGen, methodEndLabel,
Expand Down Expand Up @@ -639,7 +639,7 @@ void generateBasicBlocks(MethodVisitor mv, LabelGenerator labelGen, JvmErrorGen
lastScope = JvmCodeGenUtil
.getLastScopeFromTerminator(mv, bb, funcName, labelGen, lastScope, visitedScopesSet);

errorGen.generateTryCatch(func, funcName, bb, termGen, labelGen, invocationVarIndex);
errorGen.generateTryCatch(func, funcName, bb, termGen, labelGen, invocationVarIndex, localVarOffset);

String yieldStatus = getYieldStatusByTerminator(terminator);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,11 @@ public void testErrorReturnFunction() {
CompileResult compileResult = BCompileUtil.compile("test-src/services/error_return_function.bal");
BRunUtil.invoke(compileResult, "testErrorFunction");
}

@Test(description = "Tests invoking a function with default worker returning an error union value in a service")
public void testErrorReturnFunctionWithDistinctListenerArg() {
CompileResult compileResult = BCompileUtil.compile(
"test-src/services/error_union_return_with_default_worker.bal");
BRunUtil.invoke(compileResult, "testErrorUnionWithDefaultWorkerFunction");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023, WSO2 LLC. (https://www.wso2.com).
//
// WSO2 LLC. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/test;

public function testErrorUnionWithDefaultWorkerFunction() {
Class c = new(10);
test:assertEquals(10, checkpanic c.getId());
}

public class Class {
private int id;

public function init(int id) {
self.id = id;
}

public function getId() returns int|error {
worker w1 returns error? {
self.id -> function;
}
return <- w1;
}
}

0 comments on commit 79694aa

Please sign in to comment.