Skip to content

Commit

Permalink
forst_PR
Browse files Browse the repository at this point in the history
  • Loading branch information
cpz2024 committed Dec 27, 2024
1 parent fe6c738 commit 65ae719
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 24 deletions.
5 changes: 2 additions & 3 deletions sml/feature_selection/tests/chi2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
import unittest

import numpy as np
from sklearn.datasets import load_iris
from sklearn.feature_selection import chi2 as chi2_sklearn

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim
from sklearn.datasets import load_iris
from sklearn.feature_selection import chi2 as chi2_sklearn

sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from sml.feature_selection.univariate_selection import chi2
Expand Down
6 changes: 3 additions & 3 deletions sml/metrics/regression/regression_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

import sml.utils.emulation as emulation
from sml.metrics.regression.regression import (
d2_tweedie_score,
explained_variance_score,
mean_gamma_deviance,
mean_poisson_deviance,
mean_squared_error,
mean_poisson_deviance,
mean_gamma_deviance,
d2_tweedie_score,
)


Expand Down
6 changes: 3 additions & 3 deletions sml/metrics/regression/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

from sml.metrics.regression.regression import (
d2_tweedie_score,
explained_variance_score,
mean_gamma_deviance,
mean_poisson_deviance,
mean_squared_error,
mean_poisson_deviance,
mean_gamma_deviance,
d2_tweedie_score,
)


Expand Down
3 changes: 1 addition & 2 deletions spu/tests/pir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
import unittest

import multiprocess
from google.protobuf import json_format

import spu.libspu.link as link
import spu.psi as psi
from google.protobuf import json_format
from spu.tests.utils import create_link_desc


Expand Down
19 changes: 9 additions & 10 deletions spu/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,27 @@
# limitations under the License.

from .distributed_impl import ( # type: ignore
PYU,
RPC,
SAMPLE_DEVICES_DEF,
SAMPLE_NODES_DEF,
PYU,
SPU,
Framework,
current,
init,
device,
dtype_spu_to_np,
get,
init,
load,
save,
current,
set_framework,
SAMPLE_NODES_DEF,
SAMPLE_DEVICES_DEF,
dtype_spu_to_np,
shape_spu_to_np,
save,
load,
)


def main():
import argparse
import json

from spu.utils.polyfill import Process

parser = argparse.ArgumentParser(description='SPU node service.')
Expand Down Expand Up @@ -71,4 +70,4 @@ def main():


if __name__ == '__main__':
main()
main()
7 changes: 4 additions & 3 deletions spu/utils/distributed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import numpy as np
from termcolor import colored


from .. import api as spu_api
from .. import libspu # type: ignore
from .. import spu_pb2
Expand Down Expand Up @@ -727,7 +728,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]):
except ImportError:
import jax.linear_util as lu # fallback
from jax._src import api_util as japi_util
from jax.tree_util import tree_flatten, tree_map
from jax.tree_util import tree_map, tree_flatten

mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs))

Expand Down Expand Up @@ -956,7 +957,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]):
fn, torch.nn.Module
), "currently only torch.nn.Module is supported"

from jax.tree_util import tree_flatten, tree_map
from jax.tree_util import tree_map, tree_flatten

mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs))

Expand Down Expand Up @@ -1379,4 +1380,4 @@ def SPU2SPU(to: SPU, obj: SPU.Object):
},
"P1": {"kind": "PYU", "config": {"node_id": "node:0"}},
"P2": {"kind": "PYU", "config": {"node_id": "node:1"}},
}
}

0 comments on commit 65ae719

Please sign in to comment.