package com.qubole.nezha; import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import java.util.List; import java.util.Map; import java.util.SortedSet; import javax.sql.DataSource; import javax.validation.Validation; import javax.validation.Validator; import com.codahale.metrics.MetricRegistry; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.Resources; import com.qubole.nezha.NezhaConfiguration; import io.dropwizard.configuration.ConfigurationFactory; import io.dropwizard.db.DataSourceFactory; import io.dropwizard.flyway.FlywayFactory; import io.dropwizard.jackson.Jackson; import org.flywaydb.core.Flyway; import org.hibernate.SessionFactory; import org.hibernate.boot.registry.StandardServiceRegistryBuilder; import org.hibernate.cfg.AvailableSettings; import org.hibernate.cfg.Configuration; import org.hibernate.context.internal.ManagedSessionContext; import org.hibernate.service.ServiceRegistry; import org.hibernate.engine.jdbc.connections.internal.DatasourceConnectionProviderImpl; import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider; import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import io.dropwizard.configuration.ConfigurationException; import io.dropwizard.db.ManagedDataSource; /* Inspired by https://gist.github.com/Sch3lp/9185192 */ public class DropwizardHibernateRule implements TestRule { private SessionFactory sessionFactory; private NezhaConfiguration nezhaConfiguration; private final ImmutableList> entities; private final String configPath; private final ObjectMapper MAPPER = Jackson.newObjectMapper(); private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator(); public static DropwizardHibernateRule create(String configPath, ImmutableList> entities) { return new DropwizardHibernateRule(configPath, entities); } private DropwizardHibernateRule(String configPath, ImmutableList> entities) { this.configPath = configPath; this.entities = entities; } public SessionFactory getSessionFactory() { return sessionFactory; } @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { before(); try { base.evaluate(); } finally { after(); } } }; } protected void before() throws IOException, ConfigurationException, ClassNotFoundException, URISyntaxException { ConfigurationFactory factory = new ConfigurationFactory<>( NezhaConfiguration.class, validator, MAPPER, ""); nezhaConfiguration = factory.build(new File(Resources.getResource(configPath).toURI())); final MetricRegistry metricRegistry = new MetricRegistry(); DataSourceFactory dataSourceFactory = nezhaConfiguration.getDataSourceFactory(); final ManagedDataSource dataSource = dataSourceFactory.build(metricRegistry, "Rule"); final ConnectionProvider provider = buildConnectionProvider(dataSource, dataSourceFactory.getProperties()); sessionFactory = buildSessionFactory(dataSourceFactory, provider, ImmutableMap. of(), entities); // open session/transaction ManagedSessionContext.bind(sessionFactory.openSession()); FlywayFactory flywayFactory = nezhaConfiguration.getFlywayFactory(); final String[] fwLocations = new String[flywayFactory.getLocations().size()]; Flyway flyway = new Flyway(); flyway.setDataSource(nezhaConfiguration.getDataSourceFactory().getUrl(), nezhaConfiguration.getDataSourceFactory().getUser(), nezhaConfiguration.getDataSourceFactory().getPassword()); flyway.setLocations(flywayFactory.getLocations().toArray(fwLocations)); flyway.migrate(); } protected void after() { // close session/transaction ManagedSessionContext.unbind(sessionFactory); } /** * From io.dropwizard.hibernate.SessionFactoryFactory */ private ConnectionProvider buildConnectionProvider(DataSource dataSource, Map properties) { final DatasourceConnectionProviderImpl connectionProvider = new DatasourceConnectionProviderImpl(); connectionProvider.setDataSource(dataSource); connectionProvider.configure(properties); return connectionProvider; } /** * From com.yammer.dropwizard.hibernate.SessionFactoryFactory */ private SessionFactory buildSessionFactory(DataSourceFactory dbConfig, ConnectionProvider connectionProvider, ImmutableMap properties, List> entities) { final Configuration configuration = new Configuration(); configuration.setProperty(AvailableSettings.CURRENT_SESSION_CONTEXT_CLASS, "managed"); configuration.setProperty(AvailableSettings.USE_SQL_COMMENTS, Boolean.toString(dbConfig.isAutoCommentsEnabled())); configuration.setProperty(AvailableSettings.USE_GET_GENERATED_KEYS, "true"); configuration.setProperty(AvailableSettings.GENERATE_STATISTICS, "true"); configuration.setProperty(AvailableSettings.USE_REFLECTION_OPTIMIZER, "true"); configuration.setProperty(AvailableSettings.ORDER_UPDATES, "true"); configuration.setProperty(AvailableSettings.ORDER_INSERTS, "true"); configuration.setProperty(AvailableSettings.USE_NEW_ID_GENERATOR_MAPPINGS, "true"); configuration.setProperty("jadira.usertype.autoRegisterUserTypes", "true"); for (Map.Entry property : properties.entrySet()) { configuration.setProperty(property.getKey(), property.getValue()); } addAnnotatedClasses(configuration, entities); final ServiceRegistry registry = new StandardServiceRegistryBuilder() .addService(ConnectionProvider.class, connectionProvider) .applySettings(properties) .build(); return configuration.buildSessionFactory(registry); } /** * From com.yammer.dropwizard.hibernate.SessionFactoryFactory */ private void addAnnotatedClasses(Configuration configuration, Iterable> entities) { final SortedSet entityClasses = Sets.newTreeSet(); for (Class klass : entities) { configuration.addAnnotatedClass(klass); entityClasses.add(klass.getCanonicalName()); } } }