# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import json
import os
import re
from contextlib import contextmanager
from pathlib import Path

from git import Repo


# Root of the repository, used as the base directory for most path lookups below. This script is intended to be run
# from the root of the repo but you can adapt this constant if you need to.
PATH_TO_TRANFORMERS = "."

# A temporary way to trigger all pipeline tests contained in model test files after PR #21516.
# NOTE: this list and the next one are computed once, at import time, from paths that are relative to the current
# working directory (they do not go through `PATH_TO_TRANFORMERS`).
all_model_test_files = [str(x) for x in Path("tests/models/").glob("**/**/test_modeling_*.py")]

# Every `test_pipelines_*.py` file found (recursively) under `tests/pipelines/`.
all_pipeline_test_files = [str(x) for x in Path("tests/pipelines/").glob("**/test_pipelines_*.py")]


@contextmanager
def checkout_commit(repo, commit_id):
    """
    Temporarily check out `commit_id` in `repo`.

    The position of HEAD at entry (the commit itself when HEAD is detached, the branch reference otherwise) is
    recorded first and checked out again on exit, even if the body of the `with` block (or the initial checkout)
    raises.
    """
    head = repo.head
    # Remember where to come back to: a detached HEAD has no branch, so we fall back to the raw commit.
    restore_point = head.commit if head.is_detached else head.ref

    try:
        repo.git.checkout(commit_id)
        yield
    finally:
        repo.git.checkout(restore_point)


def clean_code(content):
    """
    Strip docstrings, comments and blank lines from `content`.

    Docstrings are removed textually: the source is split on each triple-quote delimiter and only the pieces found
    at even positions (i.e. outside a quoted section) are kept. Anything following a `#` on a line is then dropped,
    and lines that end up empty or whitespace-only are discarded.

    Returns the remaining lines joined with newlines.
    """
    # Triple double quotes first, then triple single quotes. After a split, odd indices are the quoted sections.
    for delimiter in ('"""', "'''"):
        content = "".join(content.split(delimiter)[::2])

    kept_lines = []
    for raw_line in content.split("\n"):
        # Drop the comment part of the line (everything from the first `#` to the end).
        code_part = re.sub("#.*$", "", raw_line)
        if code_part and not code_part.isspace():
            kept_lines.append(code_part)
    return "\n".join(kept_lines)


def get_all_tests():
    """
    List every test folder and top-level test file, as paths starting with `tests`.

    The result is made of, in this order:

    - the (sorted) folders found under `tests/models`: `tests/models/bert`, `tests/models/gpt2`, etc.
    - the (sorted) entries directly under `tests` that are either folders (`tests/tokenization`,
      `tests/pipelines`, etc.) or files whose name starts with `test_` (`tests/test_modeling_common.py`, etc.).
      The folder `tests/models` itself is excluded since its subfolders are listed individually.
    """
    tests_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")

    # Folders and `test_*` files sitting directly under `tests`.
    top_level = []
    for entry in os.listdir(tests_dir):
        candidate = f"tests/{entry}"
        if os.path.isdir(candidate) or candidate.startswith("tests/test_"):
            top_level.append(candidate)
    top_level.sort()

    # One folder per model under `tests/models`.
    model_folders = []
    for entry in os.listdir(os.path.join(tests_dir, "models")):
        candidate = f"tests/models/{entry}"
        if os.path.isdir(candidate):
            model_folders.append(candidate)
    model_folders.sort()

    top_level.remove("tests/models")
    return model_folders + top_level


def diff_is_docstring_only(repo, branching_point, filename):
    """
    Tell whether `filename` differs from its version at `branching_point` only by docstrings, comments or blank lines.

    The file is read once with `branching_point` checked out and once in the current state; both versions go through
    `clean_code` and the function returns `True` when the cleaned contents are equal.
    """
    source_path = Path(filename)

    # Version of the file at the branching point (the previous HEAD is restored when leaving the block).
    with checkout_commit(repo, branching_point):
        before = source_path.read_text(encoding="utf-8")

    # Version of the file in the current checkout.
    after = source_path.read_text(encoding="utf-8")

    return clean_code(before) == clean_code(after)


def get_modified_python_files(diff_with_last_commit=False):
    """
    Collect the python files modified in the repository located at `PATH_TO_TRANFORMERS`.

    Args:
        diff_with_last_commit: when `False` (default), the current head is compared with the merge base(s) of the
            current head and the `main` branch; when `True`, it is compared with its parent commit(s).

    Returns:
        The list of file paths produced by `get_diff` for that comparison.
    """
    repo = Repo(PATH_TO_TRANFORMERS)

    if diff_with_last_commit:
        print(f"main is at {repo.head.commit}")
        reference_commits = repo.head.commit.parents
        for commit in reference_commits:
            print(f"Parent commit: {commit}")
    else:
        print(f"main is at {repo.refs.main.commit}")
        print(f"Current head is at {repo.head.commit}")

        reference_commits = repo.merge_base(repo.refs.main, repo.head)
        for commit in reference_commits:
            print(f"Branching commit: {commit}")

    return get_diff(repo, repo.head.commit, reference_commits)


def get_diff(repo, base_commit, commits):
    """
    Get the python files that differ between `base_commit` and each commit in `commits`.

    Args:
        repo: the repository object, only used to check out older versions of modified files.
        base_commit: the commit every entry of `commits` is diffed against.
        commits: the commits to compare with `base_commit`.

    Returns:
        A list of paths: added and deleted python files, both names of renamed python files, and modified python
        files whose changes are not limited to docstrings/comments.
    """
    print("\n### DIFF ###\n")
    changed_files = []
    for commit in commits:
        for change in commit.diff(base_commit):
            kind = change.change_type
            if kind == "A":
                # We always add new python files
                if change.b_path.endswith(".py"):
                    changed_files.append(change.b_path)
            elif kind == "D":
                # We check that deleted python files won't break corresponding tests.
                if change.a_path.endswith(".py"):
                    changed_files.append(change.a_path)
            elif kind in ["M", "R"]:
                if not change.b_path.endswith(".py"):
                    continue
                if change.a_path != change.b_path:
                    # In case of renames, we'll look at the tests using both the old and new name.
                    changed_files.extend([change.a_path, change.b_path])
                elif diff_is_docstring_only(repo, commit, change.b_path):
                    print(f"Ignoring diff in {change.b_path} as it only concerns docstrings or comments.")
                else:
                    # Real code modification.
                    changed_files.append(change.a_path)

    return changed_files


def get_module_dependencies(module_fname):
    """
    List the files of the repo that the module `module_fname` imports from.

    Both relative imports (`from .xxx import yyy`) and absolute ones (`from transformers.xxx import yyy`) are
    detected with regular expressions. An import line containing `# tests_ignore` is skipped. Each imported module
    is then resolved to either `<module>.py` or `<module>/__init__.py`, whichever exists under
    `PATH_TO_TRANFORMERS`; imports that resolve to neither are dropped.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f:
        source = f.read()

    own_parts = module_fname.split(os.path.sep)
    candidates = []

    # Let's start with relative imports
    for target, imported_names in re.findall(r"from\s+(\.+\S+)\s+import\s+([^\n]+)\n", source):
        if "# tests_ignore" in imported_names:
            continue
        # The number of leading dots tells how many path components to drop from the current module.
        dotted_path = target.lstrip(".")
        level = len(target) - len(dotted_path)
        base_parts = own_parts[: len(own_parts) - level]
        tail_parts = dotted_path.split(".") if dotted_path else ["__init__.py"]
        candidate = os.path.sep.join(base_parts + tail_parts)
        # We ignore the main init import as it's only for the __version__ that it's done
        # and it would add everything as a dependency.
        if not candidate.endswith("transformers/__init__.py"):
            candidates.append(candidate)

    # Let's continue with direct imports
    # The import from the transformers module are ignored for the same reason we ignored the
    # main init before.
    for target, imported_names in re.findall(r"from\s+transformers\.(\S+)\s+import\s+([^\n]+)\n", source):
        if "# tests_ignore" in imported_names:
            continue
        candidates.append(os.path.sep.join(["src", "transformers"] + target.split(".")))

    # Now let's just check that we have proper module files, or append an init for submodules
    dependencies = []
    for candidate in candidates:
        module_file = f"{candidate}.py"
        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, module_file)):
            dependencies.append(module_file)
            continue
        is_package = os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, candidate)) and os.path.isfile(
            os.path.sep.join([PATH_TO_TRANFORMERS, candidate, "__init__.py"])
        )
        if is_package:
            dependencies.append(os.path.sep.join([candidate, "__init__.py"]))
    return dependencies


def get_test_dependencies(test_fname):
    """
    List the test files that the test file `test_fname` imports from.

    Only relative imports (`from .xxx import yyy`, `from ..xxx import yyy`, ...) are looked at, and an import line
    containing `# tests_ignore` is skipped. Each import is converted to a file path relative to the location of
    `test_fname`, and only the paths that exist under `PATH_TO_TRANFORMERS` are returned.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f:
        source = f.read()

    own_parts = test_fname.split(os.path.sep)
    dependencies = []

    # Tests only have relative imports for other test files
    # TODO Sylvain: handle relative imports cleanly
    for target, imported_names in re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", source):
        if "# tests_ignore" in imported_names:
            continue

        # Each leading dot removes one trailing component from the path of the test file.
        dotted_path = target.lstrip(".")
        level = len(target) - len(dotted_path)
        directory = os.path.sep.join(own_parts[:-level])
        relative_file = dotted_path.replace(".", os.path.sep)
        candidate = os.path.join(directory, f"{relative_file}.py")

        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, candidate)):
            dependencies.append(candidate)

    return dependencies


def create_reverse_dependency_tree():
    """
    Build the list of all edges `(a, b)`, each meaning that modifying `a` impacts `b`.

    `b` goes over every python file under `src/transformers` (edges from `get_module_dependencies`) and then every
    python file under `tests` (edges from `get_test_dependencies`).
    """
    repo_root = Path(PATH_TO_TRANFORMERS)

    def _python_files_under(folder):
        return [str(path.relative_to(PATH_TO_TRANFORMERS)) for path in (repo_root / folder).glob("**/*.py")]

    edges = []

    module_files = _python_files_under("src/transformers")
    for module_file in module_files:
        for dependency in get_module_dependencies(module_file):
            edges.append((dependency, module_file))

    test_files = _python_files_under("tests")
    for test_file in test_files:
        for dependency in get_test_dependencies(test_file):
            edges.append((dependency, test_file))

    return edges


def get_tree_starting_at(module, edges):
    """
    Returns the tree starting at a given module following all edges, level by level.

    Args:
        module: the vertex the traversal starts from.
        edges: a list of `(start, end)` pairs (it is traversed once per level).

    Returns:
        A list in the format `[module, [edges starting at module], [edges starting at the vertices reached at the
        preceding level], ...]`. Inside each level, edges keep the order they have in `edges`. Self-loops on
        `module` are skipped and an edge is only followed when it leads to a vertex that was not reached at an
        earlier level, so cycles cannot make the traversal loop forever.
    """
    # Sets give constant-time membership checks; the original lists made every level quadratic in the worst case.
    vertices_seen = {module}
    level_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
    tree = [module]
    while level_edges:
        tree.append(level_edges)
        reached = {edge[1] for edge in level_edges}
        vertices_seen |= reached
        level_edges = [edge for edge in edges if edge[0] in reached and edge[1] not in vertices_seen]

    return tree


def print_tree_deps_of(module, all_edges=None):
    """
    Prints the tree of modules depending on a given module.

    Args:
        module: the file whose dependents should be displayed.
        all_edges: optional list of `(a, b)` edges meaning "modifying a impacts b". When `None`, it is computed with
            `create_reverse_dependency_tree`.

    Each dependent is printed on its own line, indented by two spaces per level, right below the file it depends
    on. Siblings are printed in sorted order so that the output is the same from one run to the next (iterating
    directly over sets of strings gives an order that can change between interpreter runs).
    """
    if all_edges is None:
        all_edges = create_reverse_dependency_tree()
    tree = get_tree_starting_at(module, all_edges)

    # The list of lines is a list of tuples (line_to_be_printed, module)
    # Keeping the modules lets us know where to insert each new lines in the list.
    lines = [(tree[0], tree[0])]
    for depth, edges in enumerate(tree[1:], start=1):
        indent = " " * (2 * depth)

        for start in sorted({edge[0] for edge in edges}):
            ends = sorted({edge[1] for edge in edges if edge[0] == start})
            # We will insert all those edges just after the (first) line showing start.
            pos = 0
            while lines[pos][1] != start:
                pos += 1
            lines[pos + 1 : pos + 1] = [(indent + end, end) for end in ends]

    for text, _ in lines:
        # We don't print the refs that were just here to help build lines.
        print(text)


def create_reverse_dependency_map():
    """
    Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
    recursively).

    Returns:
        A `collections.defaultdict(list)` whose keys are file paths relative to `PATH_TO_TRANFORMERS` and whose
        values are the files impacted by a change in the key.

    Raises:
        ValueError: if a file depends on something that is not one of the python files found under
            `src/transformers` or `tests`.
    """
    modules = [
        str(f.relative_to(PATH_TO_TRANFORMERS))
        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
    ]
    # We grab all the dependencies of each module.
    direct_deps = {m: get_module_dependencies(m) for m in modules}

    # We add all the dependencies of each test file
    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
    direct_deps.update({t: get_test_dependencies(t) for t in tests})

    all_files = modules + tests

    # This recurses the dependencies: we keep adding the dependencies of each dependency until a full pass over all
    # files adds nothing new. The lists in `direct_deps` are extended in place (including while they are being
    # iterated), so after this loop they hold indirect dependencies as well as direct ones.
    something_changed = True
    while something_changed:
        something_changed = False
        for m in all_files:
            for d in direct_deps[m]:
                if d not in direct_deps:
                    raise ValueError(f"KeyError:{d}. From {m}")
                for dep in direct_deps[d]:
                    if dep not in direct_deps[m]:
                        direct_deps[m].append(dep)
                        something_changed = True

    # Finally we can build the reverse map.
    reverse_map = collections.defaultdict(list)
    for m in all_files:
        # For an init file, everything it depends on is also recorded as impacted by a change in that init.
        if m.endswith("__init__.py"):
            reverse_map[m].extend(direct_deps[m])
        # In every case, `m` is impacted by a change in any of its dependencies.
        for d in direct_deps[m]:
            reverse_map[d].append(m)

    return reverse_map


# Any module file that has a test name which can't be inferred automatically from its name should go here. A better
# approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
#
# Keys are module paths with their first two components removed (`module_to_test_file` looks up
# `os.path.sep.join(splits[2:])`), values are either one test path or a list of test paths; `tests/` is prepended to
# each value by `module_to_test_file`. A value may also designate a folder (like `deepspeed/`), in which case
# `sanity_check` expands it to the test files it contains.
SPECIAL_MODULE_TO_TEST_MAP = {
    "commands/add_new_model_like.py": "utils/test_add_new_model_like.py",
    "configuration_utils.py": "test_configuration_common.py",
    "convert_graph_to_onnx.py": "onnx/test_onnx.py",
    "data/data_collator.py": "trainer/test_data_collator.py",
    "deepspeed.py": "deepspeed/",
    "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
    "feature_extraction_utils.py": "test_feature_extraction_common.py",
    "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
    "image_processing_utils.py": ["test_image_processing_common.py", "utils/test_image_processing_utils.py"],
    "image_transforms.py": "test_image_transforms.py",
    "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
    "utils/hub.py": "utils/test_hub_utils.py",
    "modelcard.py": "utils/test_model_card.py",
    "modeling_flax_utils.py": "test_modeling_flax_common.py",
    "modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
    "modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
    "models/auto/modeling_auto.py": [
        "models/auto/test_modeling_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_bort.py",
        "models/dit/test_modeling_dit.py",
    ],
    "models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
    "models/auto/modeling_tf_auto.py": [
        "models/auto/test_modeling_tf_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_tf_bort.py",
    ],
    "models/gpt2/modeling_gpt2.py": [
        "models/gpt2/test_modeling_gpt2.py",
        "models/megatron_gpt2/test_modeling_megatron_gpt2.py",
    ],
    "models/dpt/modeling_dpt.py": [
        "models/dpt/test_modeling_dpt.py",
        "models/dpt/test_modeling_dpt_hybrid.py",
    ],
    "optimization.py": "optimization/test_optimization.py",
    "optimization_tf.py": "optimization/test_optimization_tf.py",
    # Changes in the core of the pipelines trigger every pipeline test file and every model test file.
    "pipelines/__init__.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/base.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/text2text_generation.py": [
        "pipelines/test_pipelines_text2text_generation.py",
        "pipelines/test_pipelines_summarization.py",
        "pipelines/test_pipelines_translation.py",
    ],
    "pipelines/zero_shot_classification.py": "pipelines/test_pipelines_zero_shot.py",
    "testing_utils.py": "utils/test_skip_decorators.py",
    "tokenization_utils.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_base.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_fast.py": [
        "test_tokenization_common.py",
        "tokenization/test_tokenization_utils.py",
        "tokenization/test_tokenization_fast.py",
    ],
    "trainer.py": [
        "trainer/test_trainer.py",
        "extended/test_trainer_ext.py",
        "trainer/test_trainer_distributed.py",
        "trainer/test_trainer_tpu.py",
    ],
    # NOTE(review): this key looks like a typo for `trainer_pt_utils.py` (the other keys around it use the `trainer`
    # prefix) -- confirm which module file actually exists before changing it.
    "train_pt_utils.py": "trainer/test_trainer_utils.py",
    "utils/versions.py": "utils/test_versions_utils.py",
}


def module_to_test_file(module_fname):
    """
    Returns the name of the file(s) where `module_fname` is tested.

    The result is either a single path, a list of paths, or `None` when no test file could be associated with the
    module. Entries of `SPECIAL_MODULE_TO_TEST_MAP` take precedence; otherwise the test file is guessed from the
    location and the name of the module, and only returned if it exists on disk (except for the folder-wide special
    cases, which are returned without any check).
    """
    parts = module_fname.split(os.path.sep)

    # Special map has priority
    special_key = os.path.sep.join(parts[2:])
    if special_key in SPECIAL_MODULE_TO_TEST_MAP:
        mapped = SPECIAL_MODULE_TO_TEST_MAP[special_key]
        if isinstance(mapped, str):
            return f"tests/{mapped}"
        return [f"tests/{f}" for f in mapped]

    basename = parts[-1]
    # Fast tokenizers are tested in the same file as the slow ones.
    if basename.endswith("_fast.py"):
        basename = basename.replace("_fast.py", ".py")

    # Some folders are handled as a whole, whatever the module inside them.
    parent_folder = parts[-2] if len(parts) >= 2 else None
    if parent_folder == "pipelines":
        return [f"tests/pipelines/test_pipelines_{basename}"] + all_model_test_files
    if parent_folder == "benchmark":
        return ["tests/benchmark/test_benchmark.py", "tests/benchmark/test_benchmark_tf.py"]
    if parent_folder == "commands":
        return "tests/utils/test_cli.py"
    if parent_folder == "onnx":
        return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]

    if parts[0] == "utils":
        # Special case for utils (not the one in src/transformers, the ones at the root of the repo).
        candidate = f"tests/repo_utils/test_{basename}"
    elif len(parts) > 4 and parts[2] == "models":
        candidate = f"tests/models/{parts[3]}/test_{basename}"
    elif len(parts) > 2 and parts[2].startswith("generation"):
        candidate = f"tests/generation/test_{basename}"
    elif len(parts) > 2 and parts[2].startswith("trainer"):
        candidate = f"tests/trainer/test_{basename}"
    else:
        candidate = f"tests/utils/test_{basename}"

    if os.path.isfile(candidate):
        return candidate

    # Processing -> processor
    if "processing" in candidate:
        processor_candidate = candidate.replace("processing", "processor")
        if os.path.isfile(processor_candidate):
            return processor_candidate

    # No test file could be inferred for this module.
    return None


# Test files we expect never to be launched from a change in a module/util (those are launched separately).
# `sanity_check` does not complain when one of these files is not associated with any module or utils file.
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
    "tests/generation/test_framework_agnostic.py",  # Mixins inherited by actual test classes
    "tests/mixed_int8/test_mixed_int8.py",  # Mixed-int8 bitsandbytes test
    "tests/pipelines/test_pipelines_common.py",  # Actually checked by the pipeline based file
    "tests/sagemaker/test_single_node_gpu.py",  # SageMaker test
    "tests/sagemaker/test_multi_node_model_parallel.py",  # SageMaker test
    "tests/sagemaker/test_multi_node_data_parallel.py",  # SageMaker test
    "tests/test_pipeline_mixin.py",  # Contains no test of its own (only the common tester class)
    "tests/utils/test_doc_samples.py",  # Doc tests
]


def _print_list(l):
    return "\n".join([f"- {f}" for f in l])


def sanity_check():
    """
    Checks that all test files can be touched by a modification in at least one module/utils. This test ensures that
    newly-added test files are properly mapped to some module or utils, so they can be run by the CI.

    Raises:
        ValueError: if some test file under `tests` is neither reachable through `module_to_test_file` nor listed
            in `EXPECTED_TEST_FILES_NEVER_TOUCHED`.
    """
    repo_root = Path(PATH_TO_TRANFORMERS)

    def _relative_matches(folder, pattern):
        # Paths (relative to the repo root) of the files under `folder` matching `pattern`.
        return [str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (repo_root / folder).glob(pattern)]

    # Grab all module and utils
    source_files = _relative_matches("src/transformers", "**/*.py")
    source_files += _relative_matches("utils", "**/*.py")

    # Compute all the test files we get from those.
    mapped_tests = []
    for source_file in source_files:
        mapped = module_to_test_file(source_file)
        if mapped is None:
            continue
        if isinstance(mapped, str):
            mapped_tests.append(mapped)
        else:
            mapped_tests.extend(mapped)

    # Some of the test files might actually be subfolders so we grab the tests inside.
    reachable_tests = []
    for entry in mapped_tests:
        if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, entry)):
            reachable_tests.extend(_relative_matches(entry, "**/test*.py"))
        else:
            reachable_tests.append(entry)

    # Compare to existing test files
    existing_tests = _relative_matches("tests", "**/test*.py")
    untouched_tests = [f for f in existing_tests if f not in reachable_tests]

    should_be_tested = set(untouched_tests) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
    if should_be_tested:
        raise ValueError(
            "The following test files are not currently associated with any module or utils files, which means they "
            f"will never get run by the CI:\n{_print_list(should_be_tested)}\n. Make sure the names of these test "
            "files match the name of the module or utils they are testing, or adapt the constant "
            "`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
            "separately and is not supposed to be run by the regular CI, add it to the "
            "`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
        )


def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
    """
    Determine the tests impacted by the current diff and write them to disk.

    Args:
        output_file: path of the file receiving the selected test files/folders, separated by spaces. It is only
            written when at least one test is selected. When tests under `tests/repo_utils` are selected (or when
            `setup.py` is impacted), a file `test_repo_utils.txt` is also written in the same folder.
        diff_with_last_commit: forwarded to `get_modified_python_files` to choose what the head is compared with.
        filters: optional list of path prefixes; only the test files starting with one of them are kept. It is not
            applied when `setup.py` is impacted (everything under `tests` is selected in that case).
        json_output_file: optional path of a JSON file receiving a map from test category (like `models/bert`,
            `pipelines` or `common`) to the space-separated test files of that category.
    """
    modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
    print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")

    # Create the map that will give us all impacted modules.
    impacted_modules_map = create_reverse_dependency_map()
    impacted_files = modified_files.copy()
    for f in modified_files:
        if f in impacted_modules_map:
            impacted_files.extend(impacted_modules_map[f])

    # Remove duplicates
    impacted_files = sorted(set(impacted_files))
    print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

    # A change in `setup.py` triggers the whole test suite (including the repo utils tests).
    if "setup.py" in impacted_files:
        test_files_to_run = ["tests"]
        repo_utils_launch = True
    else:
        # Grab the corresponding test files:
        test_files_to_run = []
        for f in impacted_files:
            # Modified test files are always added
            if f.startswith("tests/"):
                test_files_to_run.append(f)
            # Example files are tested separately
            elif f.startswith("examples/pytorch"):
                test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
                test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
            elif f.startswith("examples/tensorflow"):
                test_files_to_run.append("examples/tensorflow/test_tensorflow_examples.py")
            elif f.startswith("examples/flax"):
                test_files_to_run.append("examples/flax/test_flax_examples.py")
            else:
                # `module_to_test_file` returns a path, a list of paths, or `None` when it finds no test file.
                new_tests = module_to_test_file(f)
                if new_tests is not None:
                    if isinstance(new_tests, str):
                        test_files_to_run.append(new_tests)
                    else:
                        test_files_to_run.extend(new_tests)

        # Remove duplicates
        test_files_to_run = sorted(set(test_files_to_run))
        # Make sure we did not end up with a test file that was removed
        test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
        if filters is not None:
            # Keep the files matching at least one filter. NOTE: a file whose path starts with several of the filters
            # is added once per matching filter, and the result is grouped by filter rather than sorted.
            filtered_files = []
            for filter in filters:
                filtered_files.extend([f for f in test_files_to_run if f.startswith(filter)])
            test_files_to_run = filtered_files
        # The second path component is the test folder (`tests/repo_utils/...` -> `repo_utils`).
        repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in test_files_to_run)

    if repo_utils_launch:
        # This file is written next to `output_file`.
        repo_util_file = Path(output_file).parent / "test_repo_utils.txt"
        with open(repo_util_file, "w", encoding="utf-8") as f:
            f.write("tests/repo_utils")

    print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}")
    if len(test_files_to_run) > 0:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(" ".join(test_files_to_run))

        # Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...]

        # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
        # contains `tests` (i.e. when `setup.py` is changed).
        # This happens after `output_file` was written, so only the JSON map uses the expanded list.
        if "tests" in test_files_to_run:
            test_files_to_run = get_all_tests()

        if json_output_file is not None:
            test_map = {}
            for test_file in test_files_to_run:
                # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
                #   - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
                #   - `tests/trainer/test_trainer.py` or `tests/trainer`
                #   - `tests/test_modeling_common.py`
                names = test_file.split(os.path.sep)
                if names[1] == "models":
                    # take the part like `models/bert` for modeling tests
                    key = "/".join(names[1:3])
                elif len(names) > 2 or not test_file.endswith(".py"):
                    # test folders under `tests` or python files under them
                    # take the part like tokenization, `pipeline`, etc. for other test categories
                    key = "/".join(names[1:2])
                else:
                    # common test files directly under `tests/`
                    key = "common"

                if key not in test_map:
                    test_map[key] = []
                test_map[key].append(test_file)

            # sort the keys & values
            keys = sorted(test_map.keys())
            test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
            with open(json_output_file, "w", encoding="UTF-8") as fp:
                json.dump(test_map, fp, ensure_ascii=False)


def filter_tests(output_file, filters):
    """
    Reads the content of the output file and filters out all the tests in a list of given folders.

    The file is rewritten in place with the remaining tests (separated by spaces). Nothing is written when the file
    does not exist or contains no test.

    Args:
        output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher.
        filters (`List[str]`): A list of folders to filter.
    """
    if not os.path.isfile(output_file):
        print("No test file found.")
        return
    with open(output_file, "r", encoding="utf-8") as f:
        # Splitting on any whitespace (instead of a single space) ignores a trailing newline or repeated spaces,
        # which would otherwise produce empty entries and crash the folder lookup below.
        test_files = f.read().split()

    if not test_files:
        print("No tests to filter.")
        return

    if test_files == ["tests"]:
        # The whole test suite was requested: expand it to the content of `tests` minus the filtered folders.
        excluded = ["__init__.py"] + filters
        test_files = [os.path.join("tests", f) for f in os.listdir("tests") if f not in excluded]
    else:
        kept_files = []
        for test_file in test_files:
            # The folder of a test is the second component of its path (`tests/<folder>/...`). Entries without a
            # second component cannot belong to a filtered folder, so they are kept.
            parts = test_file.split(os.path.sep)
            if len(parts) > 1 and parts[1] in filters:
                continue
            kept_files.append(test_file)
        test_files = kept_files

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(" ".join(test_files))


if __name__ == "__main__":
    # Command line interface of the tests fetcher.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sanity_check",
        action="store_true",
        help="Only test that all tests and modules are accounted for.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="test_list.txt",
        help="Where to store the list of tests to run",
    )
    parser.add_argument(
        "--json_output_file",
        type=str,
        default="test_map.json",
        help="Where to store the tests to run in a dictionary format mapping test categories to test files",
    )
    parser.add_argument(
        "--diff_with_last_commit",
        action="store_true",
        help="To fetch the tests between the current commit and the last commit",
    )
    parser.add_argument(
        "--filters",
        type=str,
        nargs="*",
        default=["tests"],
        help="Only keep the test files matching one of those filters.",
    )
    parser.add_argument(
        "--filter_tests",
        action="store_true",
        help="Will filter the pipeline/repo utils tests outside of the generated list of tests.",
    )
    parser.add_argument(
        "--print_dependencies_of",
        type=str,
        default=None,
        help="Will only print the tree of modules depending on the file passed.",
    )
    args = parser.parse_args()

    # The modes are mutually exclusive and checked in this order of priority.
    if args.print_dependencies_of is not None:
        print_tree_deps_of(args.print_dependencies_of)
    elif args.sanity_check:
        sanity_check()
    elif args.filter_tests:
        filter_tests(args.output_file, ["pipelines", "repo_utils"])
    else:
        repo = Repo(PATH_TO_TRANFORMERS)

        diff_with_last_commit = args.diff_with_last_commit
        if not diff_with_last_commit:
            # On the main branch there is no branching point to diff against, so we compare with the last commit.
            head = repo.head
            if not head.is_detached and head.ref == repo.refs.main:
                print("main branch detected, fetching tests against last commit.")
                diff_with_last_commit = True

        try:
            infer_tests_to_run(
                args.output_file,
                diff_with_last_commit=diff_with_last_commit,
                filters=args.filters,
                json_output_file=args.json_output_file,
            )
            filter_tests(args.output_file, ["repo_utils"])
        except Exception as error:
            # Best effort: if anything goes wrong while computing the impacted tests, fall back to running all of them.
            print(f"\nError when trying to grab the relevant tests: {error}\n\nRunning all tests.")
            with open(args.output_file, "w", encoding="utf-8") as f:
                if args.filters is None:
                    f.write("./tests/")
                else:
                    f.write(" ".join(args.filters))