package com.xebialabs.platform.test.testng;

import java.util.regex.Pattern;

import org.slf4j.MDC;
import org.testng.IInvokedMethod;
import org.testng.IInvokedMethodListener;
import org.testng.ITestNGMethod;
import org.testng.ITestResult;

/**
 * TestNG listener that publishes the name of the method being invoked into the SLF4J
 * {@link MDC} before the invocation and removes it again afterwards.
 *
 * <p>Two entries are maintained: {@link #MDC_KEY_TEST} holds the full, human-readable name and
 * {@link #MDC_KEY_LOG} holds a sanitized variant of it (presumably intended for use as a log file
 * name or similar discriminator; confirm against the logging configuration).
 */
public class MdcTestListener implements IInvokedMethodListener {
    /** MDC key under which the full test name ({@code class.method(params)}) is stored. */
    public static final String MDC_KEY_TEST = "testcase";
    /** MDC key under which the sanitized test name is stored. */
    public static final String MDC_KEY_LOG = "testcaseLog";
    /** All MDC keys set by this listener; each of them is removed in {@link #afterInvocation}. */
    public static final String[] MDC_KEYS = { MDC_KEY_LOG, MDC_KEY_TEST };

    /** Package prefix stripped from the sanitized name to keep it short. */
    private static final String PACKAGE_PREFIX = "com.xebialabs.";

    /**
     * Matches the default {@code toString()} of a TestNG {@code TestRunner} parameter, whose
     * identity hash would otherwise make the sanitized name differ between runs.
     */
    private static final Pattern TEST_RUNNER_IDENTITY =
        Pattern.compile("org\\.testng\\.TestRunner@[a-f0-9]+");

    /** Runs of whitespace/punctuation characters that are collapsed into a single underscore. */
    private static final Pattern UNSAFE_CHARS =
        Pattern.compile("[ \\\\/\\.,:;\\|<>'`\"!@#$%^&\\*\\+=~]+");

    /**
     * Builds a readable name of the form {@code fully.qualified.Class.method(p1, p2, ...)}.
     *
     * @param m_method the TestNG method being invoked
     * @param params   the invocation parameters; {@code null} is treated as "no parameters"
     * @return the display name of the invocation
     */
    protected String getName(ITestNGMethod m_method, Object[] params) {
        Class<?> realClass = m_method.getRealClass();
        String className = realClass.getCanonicalName();
        if (className == null) {
            // Local and anonymous classes have no canonical name; fall back to the binary name
            // instead of failing with a NullPointerException in the StringBuilder constructor.
            className = realClass.getName();
        }

        StringBuilder result = new StringBuilder(className)
            .append(".")
            .append(m_method.getMethodName())
            .append("(");
        if (params != null) {
            String separator = "";
            for (Object param : params) {
                result.append(separator).append(param);
                separator = ", ";
            }
        }
        return result.append(")").toString();
    }

    /**
     * Stores the full and the sanitized name of the method about to be invoked in the MDC.
     *
     * @param method     the method about to be invoked
     * @param testResult the result object carrying the invocation parameters
     */
    @Override
    public void beforeInvocation(IInvokedMethod method, ITestResult testResult) {
        ITestNGMethod testMethod = method.getTestMethod();
        final String name = getName(testMethod, testResult.getParameters());
        MDC.put(MDC_KEY_TEST, name);

        // Sanitize: drop the company package prefix, hide TestRunner identity hashes, then
        // collapse every run of unsafe characters into one underscore.
        String shortened = name.replace(PACKAGE_PREFIX, "");
        String withoutIdentity = TEST_RUNNER_IDENTITY.matcher(shortened).replaceAll("TestRunner");
        String nameForLog = UNSAFE_CHARS.matcher(withoutIdentity).replaceAll("_");
        MDC.put(MDC_KEY_LOG, nameForLog);
    }

    /**
     * Removes every entry this listener placed in the MDC.
     *
     * @param method     the method that was invoked (unused)
     * @param testResult the result of the invocation (unused)
     */
    @Override
    public void afterInvocation(IInvokedMethod method, ITestResult testResult) {
        for (String key : MDC_KEYS) {
            MDC.remove(key);
        }
    }
}
