From f2202262db7bf9a7b8df4ac58fee7eac5d198025 Mon Sep 17 00:00:00 2001 From: Ian Goodfellow Date: Wed, 10 Jul 2024 11:35:01 -0700 Subject: [PATCH] Make SourceModels constructor responsible for building the Sources PiperOrigin-RevId: 651096431 --- torax/sources/source_models.py | 54 ++++++++++++++++------------ torax/sources/tests/source_models.py | 20 ++++++----- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 8aafb174..7c08abe4 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -652,32 +652,46 @@ class SourceModels: def __init__( self, - sources: dict[str, source_lib.Source] | None = None, + source_builders: ( + dict[str, source_lib.SourceBuilderProtocol] | None + ) = None, ): """Constructs a collection of sources. + The constructor should only be called by SourceModelsBuilder. + This class defines which sources are available in a TORAX simulation run. Users can configure whether each source is actually on and what kind of profile it produces by changing its runtime configuration (see runtime_params_lib.py). Args: - sources: Mapping of source model names to the Source objects. The names - (i.e. the keys of this dictionary) also define the keys in the output - SourceProfiles which are computed from this SourceModels object. NOTE - - Some sources are "special-case": bootstrap current, external current, - and Qei. SourceModels will always instantiate default objects for these - types of sources unless they are provided by this `sources` argument. - Also, their default names are reserved, meaning the input dictionary - `sources` should not have the keys 'j_bootstrap', 'jext', or - 'qei_source' unless those sources are one of these "special-case" - sources. + source_builders: Mapping of source model names to builders of the Source + objects. The names (i.e. the keys of this dictionary) also define the + keys in the output SourceProfiles which are computed from this + SourceModels object. NOTE - Some sources are "special-case": bootstrap + current, external current, and Qei. SourceModels will always instantiate + default objects for these types of sources unless they are provided by + this `sources` argument. Also, their default names are reserved, meaning + the input dictionary `sources` should not have the keys 'j_bootstrap', + 'jext', or 'qei_source' unless those sources are one of these + "special-case" sources. Raises: ValueError if there is a naming collision with the reserved names as described above. """ - sources = sources or {} + + source_builders = source_builders or {} + + # Begin initial construction with sources that don't link back to the + # SourceModels + sources = { + name: builder() + for name, builder in source_builders.items() + if not builder.links_back + } + # Some sources are accessed for specific use cases, so we extract those # ones and expose them directly. self._j_bootstrap = None @@ -741,6 +755,11 @@ def __init__( else: self.add_source(source_name, source) + # Now add the sources that link back + for name, builder in source_builders.items(): + if builder.links_back: + self.add_source(name, builder(self)) + def add_source( self, source_name: str, @@ -931,16 +950,7 @@ def __init__( def __call__(self) -> SourceModels: - unlinked_sources = { - name: builder() - for name, builder in self.source_builders.items() - if not builder.links_back - } - initial_model = SourceModels(unlinked_sources) - for name, builder in self.source_builders.items(): - if builder.links_back: - initial_model.add_source(name, builder(initial_model)) - return initial_model + return SourceModels(self.source_builders) @property def runtime_params(self) -> dict[str, runtime_params_lib.RuntimeParams]: diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index de946f02..d3b64005 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -215,22 +215,22 @@ def test_cannot_add_multiple_special_case_sources(self): with self.assertRaises(ValueError): source_models_lib.SourceModels( dict( - j_bootstrap=bootstrap_current_source.BootstrapCurrentSource(), - j_bootstrap2=bootstrap_current_source.BootstrapCurrentSource(), + j_bootstrap=bootstrap_current_source.BootstrapCurrentSourceBuilder(), + j_bootstrap2=bootstrap_current_source.BootstrapCurrentSourceBuilder(), ) ) with self.assertRaises(ValueError): source_models_lib.SourceModels( dict( - qei=qei_source.QeiSource(), - qei2=qei_source.QeiSource(), + qei=qei_source.QeiSourceBuilder(), + qei2=qei_source.QeiSourceBuilder(), ) ) with self.assertRaises(ValueError): source_models_lib.SourceModels( dict( - external_current=external_current_source.ExternalCurrentSource(), - external_current2=external_current_source.ExternalCurrentSource(), + external_current=external_current_source.ExternalCurrentSourceBuilder(), + external_current2=external_current_source.ExternalCurrentSourceBuilder(), ) ) source_models = source_models_lib.SourceModels() @@ -248,16 +248,18 @@ def test_cannot_add_multiple_special_case_sources(self): def test_cannot_add_multiple_sources_with_same_name(self): """Tests that SourceModels cannot add multiple sources with same name.""" source_name = 'foo' - foo_source = source_lib.Source( + foo_source_builder = source_lib.SourceBuilder( affected_core_profiles=(source_lib.AffectedCoreProfile.TEMP_EL,), supported_modes=(runtime_params_lib.Mode.ZERO,), ) source_models = source_models_lib.SourceModels( - sources={source_name: foo_source}, + source_builders={source_name: foo_source_builder}, ) + # It's built once by the SourceModels constructor + rebuilt = foo_source_builder() # Cannot add another source with that name again. with self.assertRaises(ValueError): - source_models.add_source(source_name, foo_source) + source_models.add_source(source_name, rebuilt) if __name__ == '__main__':