package com.xebialabs.platform.test.testng;

import java.lang.reflect.Field;
import java.util.concurrent.atomic.AtomicBoolean;
import org.testng.*;

/**
 * TestNG listener that drives the lifecycle of a {@link RuleBase} stored in a field of the
 * test class annotated with {@link TestNGRule}.
 *
 * <p>The rule is started on the first {@code @BeforeMethod} configuration method or test
 * method of an invocation cycle. It is stopped after the test method when the test class
 * has no after-test-method configuration methods, or otherwise after the last of those
 * methods. A single flag prevents starting or stopping the rule twice within a cycle.
 *
 * <p>NOTE(review): that flag is shared by every method this listener observes, so test
 * methods running in parallel would interfere with one another.
 */
public class RuleListener implements IInvokedMethodListener {

    /** {@code true} while the rule is considered started for the current invocation cycle. */
    private final AtomicBoolean initialized = new AtomicBoolean(false);

    /**
     * Starts the rule before the first eligible method of an invocation cycle.
     *
     * <p>If starting the rule fails, the flag is cleared again. This skips the matching
     * stop and lets the next cycle retry the start.
     */
    @Override
    public void beforeInvocation(IInvokedMethod method, ITestResult testResult) {
        if (shouldStart(method) && !initialized.getAndSet(true)) {
            boolean started = false;
            try {
                findAndStartRule(method);
                started = true;
            } finally {
                if (!started) {
                    initialized.set(false);
                }
            }
        }
    }

    /** Stops the rule after the last method of the cycle, if it was started. */
    @Override
    public void afterInvocation(IInvokedMethod method, ITestResult testResult) {
        if (shouldStop(method) && initialized.getAndSet(false)) {
            findAndStopRule(method);
        }
    }

    /** Returns whether {@code method} is a test method or a before-method configuration. */
    private boolean shouldStart(IInvokedMethod method) {
        return method.isTestMethod() || method.getTestMethod().isBeforeMethodConfiguration();
    }

    /**
     * Returns whether {@code method} closes the invocation cycle. That is the test method
     * itself when the class has no after-test methods, or the last after-test method
     * otherwise.
     */
    private boolean shouldStop(IInvokedMethod method) {
        ITestNGMethod[] afterTestMethods = method.getTestMethod().getTestClass().getAfterTestMethods();
        if (afterTestMethods.length == 0) {
            return method.isTestMethod();
        }
        return method.getTestMethod().equals(afterTestMethods[afterTestMethods.length - 1]);
    }

    /** Stops the rule of the method's test instance, if one is declared. */
    private void findAndStopRule(IInvokedMethod method) {
        RuleBase rule = getRule(method);
        if (rule != null) {
            rule.stop();
        }
    }

    /** Starts the rule of the method's test instance, if one is declared. */
    private void findAndStartRule(IInvokedMethod method) {
        RuleBase rule = getRule(method);
        if (rule != null) {
            rule.start();
        }
    }

    /**
     * Reads the {@link TestNGRule}-annotated field from the method's test instance.
     *
     * @return the rule, or {@code null} if no annotated field exists or the field holds null
     * @throws TestNGException if the field cannot be read reflectively
     */
    private RuleBase getRule(IInvokedMethod method) {
        Field ruleField = findRuleField(method);
        if (ruleField == null) {
            return null;
        }

        try {
            return (RuleBase) ruleField.get(method.getTestMethod().getInstance());
        } catch (IllegalAccessException e) {
            throw new TestNGException(e);
        }
    }

    /**
     * Finds the first field annotated with {@link TestNGRule} on the test's real class.
     *
     * <p>Public fields, including inherited ones, are searched first. Non-public fields
     * declared anywhere in the class hierarchy are searched next and made accessible, so a
     * private or protected rule field is no longer silently ignored.
     *
     * @return the annotated field, or {@code null} if there is none
     */
    private Field findRuleField(IInvokedMethod method) {
        Class<?> realClass = method.getTestMethod().getRealClass();
        for (Field field : realClass.getFields()) {
            if (field.isAnnotationPresent(TestNGRule.class)) {
                return field;
            }
        }
        for (Class<?> type = realClass; type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (field.isAnnotationPresent(TestNGRule.class)) {
                    // Non-public field: Field.get() would throw IllegalAccessException without this.
                    field.setAccessible(true);
                    return field;
                }
            }
        }
        return null;
    }
}
