Skip to content

Commit

Permalink
Fixed the FOBS lazy loading issue (#3121)
Browse files Browse the repository at this point in the history
* Fixed the FOBS lazy loading issue

* Fixed format

* Added check for builtins without module name
  • Loading branch information
nvidianz authored Jan 4, 2025
1 parent 62e2441 commit fe41433
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
16 changes: 7 additions & 9 deletions nvflare/fuel/utils/fobs/fobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import importlib
import inspect
import logging
Expand Down Expand Up @@ -66,15 +67,12 @@ def _get_type_name(cls: Type) -> str:

def _load_class(type_name: str):
try:
parts = type_name.split(".")
if len(parts) == 1:
parts = ["builtins", type_name]

mod = __import__(parts[0])
for comp in parts[1:]:
mod = getattr(mod, comp)

return mod
if "." in type_name:
module_name, class_name = type_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
else:
return getattr(builtins, type_name)
except Exception as ex:
raise TypeError(f"Can't load class {type_name}: {ex}")

Expand Down
9 changes: 3 additions & 6 deletions tests/unit_test/fuel/utils/fobs/fobs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ class TestFobs:
NUMBER = 123456
FLOAT = 123.456
NAME = "FOBS Test"
SET = {4, 5, 6}
NOW = datetime.now()

test_data = {
"str": "Test string",
"number": NUMBER,
"float": FLOAT,
"list": [7, 8, 9],
"set": {4, 5, 6},
"set": SET,
"tuple": ("abc", "xyz"),
"time": NOW,
}
Expand All @@ -44,11 +45,7 @@ def test_builtin(self):
buf = fobs.dumps(TestFobs.test_data)
data = fobs.loads(buf)
assert data["number"] == TestFobs.NUMBER

def test_aliases(self):
buf = fobs.dumps(TestFobs.test_data)
data = fobs.loads(buf)
assert data["number"] == TestFobs.NUMBER
assert data["set"] == TestFobs.SET

def test_unsupported_classes(self):
with pytest.raises(TypeError):
Expand Down

0 comments on commit fe41433

Please sign in to comment.