|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
- from .utils import to_json
- from .runtime.env_vars import trial_env_vars
- from .runtime import platform
-
-
# Public names exported by a star-import of this module.
__all__ = [
    'get_next_parameter',
    'get_current_parameter',
    'report_intermediate_result',
    'report_final_result',
    'get_experiment_id',
    'get_trial_id',
    'get_sequence_id',
]


# Last record returned by platform.get_next_parameter(); stays None until
# get_next_parameter() has been called.
_params = None

# Identifiers are queried from the platform layer once, at import time, and
# served from these module-level values afterwards.
_experiment_id = platform.get_experiment_id()
_trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id()
-
-
def get_next_parameter():
    """
    Fetch the next group of hyper parameters generated by the tuner.

    The record obtained from ``platform.get_next_parameter()`` is stored in the
    module-level ``_params`` so that ``get_current_parameter`` and the report
    functions can refer to it later.

    For a multiphase experiment, every call returns a new group of hyper
    parameters. For a non-multiphase experiment (multiPhase not configured or
    set to False), hyper parameters are returned only on the first call of a
    trial job and None is returned from the second call on, so this API should
    be called only once per trial job in that case.

    Returns
    -------
    dict
        The hyper parameters generated by the tuner; the keys are those defined
        in the search space. None if no more hyper parameters can be generated.
    """
    global _params
    _params = platform.get_next_parameter()
    return None if _params is None else _params['parameters']
-
def get_current_parameter(tag=None):
    """
    Return the hyper parameters obtained by the last call of get_next_parameter.

    Parameters
    ----------
    tag: str
        Optional hyper parameter key. When given, only the value stored under
        that key is returned.

    Returns
    -------
    object
        None if no parameter record is currently stored; the whole parameter
        group if ``tag`` is None; otherwise the value stored under ``tag``
        (a missing key raises KeyError).
    """
    # Nothing has been fetched yet (or the last fetch yielded nothing).
    if _params is None:
        return None
    parameters = _params['parameters']
    return parameters if tag is None else parameters[tag]
-
def get_experiment_id():
    """
    Return the identifier of the current experiment.

    The value is the one captured from ``platform.get_experiment_id()`` when
    this module was imported.

    Returns
    -------
    str
        Identifier of current experiment
    """
    return _experiment_id
-
def get_trial_id():
    """
    Return the trial job ID, a string identifier such as 'MoXrp'.

    Within one experiment every trial job has a unique string ID. The value is
    the one captured from ``platform.get_trial_id()`` when this module was
    imported.

    Returns
    -------
    str
        Identifier of current trial job which is calling this API.
    """
    return _trial_id
-
def get_sequence_id():
    """
    Return the trial job sequence number.

    A sequence number is an integer assigned to each trial job based on the
    order in which jobs are submitted, counting up from 0. Within one
    experiment both the trial job ID and the sequence number are unique per
    trial job; they differ in data type. The value is the one captured from
    ``platform.get_sequence_id()`` when this module was imported.

    Returns
    -------
    int
        Sequence number of current trial job which is calling this API.
    """
    return _sequence_id
-
# Sequence number attached to the next intermediate result report.
_intermediate_seq = 0


def overwrite_intermediate_seq(value):
    """
    Replace the intermediate-result sequence counter with ``value``.

    Parameters
    ----------
    value: int
        New value of the counter used by report_intermediate_result.

    Raises
    ------
    AssertionError
        If ``value`` is not an int (only checked when assertions are enabled).
    """
    global _intermediate_seq
    assert isinstance(value, int)
    _intermediate_seq = value
-
-
def report_intermediate_result(metric):
    """
    Report an intermediate result to NNI.

    The metric is wrapped in a 'PERIODICAL' record carrying the current
    parameter id, the trial job id and the running sequence number, then
    serialized and handed to ``platform.send_metric``. The sequence counter is
    advanced by one for every report.

    Parameters
    ----------
    metric:
        serializable object.
    """
    global _intermediate_seq
    # A parameter record must exist unless no NNI platform is configured.
    assert _params or trial_env_vars.NNI_PLATFORM is None, (
        'xcom.get_next_parameter() needs to be called before report_intermediate_result'
    )
    record = {
        'parameter_id': _params['parameter_id'] if _params else None,
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'PERIODICAL',
        'sequence': _intermediate_seq,
        'value': to_json(metric),
    }
    # Serialize before touching the counter so a serialization failure does
    # not consume a sequence number.
    serialized = to_json(record)
    _intermediate_seq += 1
    platform.send_metric(serialized)
-
def report_final_result(metric):
    """
    Report the final result to NNI.

    The metric is wrapped in a 'FINAL' record (sequence fixed at 0) carrying
    the current parameter id and the trial job id, then serialized and handed
    to ``platform.send_metric``.

    Parameters
    ----------
    metric: serializable object
        Usually (for built-in tuners to work), it should be a number, or
        a dict with key "default" (a number), and any other extra keys.
    """
    # A parameter record must exist unless no NNI platform is configured.
    assert _params or trial_env_vars.NNI_PLATFORM is None, (
        'xcom.get_next_parameter() needs to be called before report_final_result'
    )
    record = {
        'parameter_id': _params['parameter_id'] if _params else None,
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'FINAL',
        'sequence': 0,
        'value': to_json(metric),
    }
    platform.send_metric(to_json(record))
|