|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- config utilities for yml file.
- """
- import os
- import yaml
-
- class AttrDict(dict):
- """Dict as attribute trick."""
- def __init__(self, *args, **kwargs):
- super(AttrDict, self).__init__(*args, **kwargs)
- self.__dict__ = self
- for key in self.__dict__:
- value = self.__dict__[key]
- if isinstance(value, dict):
- self.__dict__[key] = AttrDict(value)
- elif isinstance(value, list):
- if isinstance(value[0], dict):
- self.__dict__[key] = [AttrDict(item) for item in value]
- else:
- self.__dict__[key] = value
-
- def yaml(self):
- """Convert object to yaml dict and return."""
- yaml_dict = {}
- for key in self.__dict__:
- value = self.__dict__[key]
- if isinstance(value, AttrDict):
- yaml_dict[key] = value.yaml()
- elif isinstance(value, list):
- if isinstance(value[0], AttrDict):
- new_l = []
- for item in value:
- new_l.append(item.yaml())
- yaml_dict[key] = new_l
- else:
- yaml_dict[key] = value
- else:
- yaml_dict[key] = value
- return yaml_dict
-
- def __repr__(self):
- """Print all variables."""
- ret_str = []
- for key in self.__dict__:
- value = self.__dict__[key]
- if isinstance(value, AttrDict):
- ret_str.append('{}:'.format(key))
- child_ret_str = value.__repr__().split('\n')
- for item in child_ret_str:
- ret_str.append(' ' + item)
- elif isinstance(value, list):
- if isinstance(value[0], AttrDict):
- ret_str.append('{}:'.format(key))
- for item in value:
- # treat as AttrDict above
- child_ret_str = item.__repr__().split('\n')
- for item0 in child_ret_str:
- ret_str.append(' ' + item0)
- else:
- ret_str.append('{}: {}'.format(key, value))
- else:
- ret_str.append('{}: {}'.format(key, value))
- return '\n'.join(ret_str)
-
-
- class Config(AttrDict):
- """Config with yaml file."""
- def __init__(self, filename=None):
- assert os.path.exists(filename), 'File {} not exist.'.format(filename)
- try:
- with open(filename, 'r') as f:
- cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
- except EnvironmentError:
- print('Please check the file with name of "%s"', filename)
- super(Config, self).__init__(cfg_dict)
-
-
- if os.getcwd().find('/cache/user-job-dir') == -1:
- if os.path.exists('./src/autoslim_cfg.yml'):
- FLAGS = Config('./src/autoslim_cfg.yml')
- else:
- FLAGS = Config('../src/autoslim_cfg.yml')
- else:
- FLAGS = Config('/cache/user-job-dir/model_files/src/autoslim_cfg.yml')
|