Created
April 15, 2015 05:58
-
-
Save vrajat/3314181737ecf28ef443 to your computer and use it in GitHub Desktop.
Revisions
-
vrajat created this gist
Apr 15, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,170 @@ package com.qubole.nezha; import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import java.util.List; import java.util.Map; import java.util.SortedSet; import javax.sql.DataSource; import javax.validation.Validation; import javax.validation.Validator; import com.codahale.metrics.MetricRegistry; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.Resources; import com.qubole.nezha.NezhaConfiguration; import io.dropwizard.configuration.ConfigurationFactory; import io.dropwizard.db.DataSourceFactory; import io.dropwizard.flyway.FlywayFactory; import io.dropwizard.jackson.Jackson; import org.flywaydb.core.Flyway; import org.hibernate.SessionFactory; import org.hibernate.boot.registry.StandardServiceRegistryBuilder; import org.hibernate.cfg.AvailableSettings; import org.hibernate.cfg.Configuration; import org.hibernate.context.internal.ManagedSessionContext; import org.hibernate.service.ServiceRegistry; import org.hibernate.engine.jdbc.connections.internal.DatasourceConnectionProviderImpl; import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider; import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import io.dropwizard.configuration.ConfigurationException; import io.dropwizard.db.ManagedDataSource; /* Inspired by https://gist.github.com/Sch3lp/9185192 */ public class DropwizardHibernateRule implements TestRule { private SessionFactory sessionFactory; private NezhaConfiguration nezhaConfiguration; private final ImmutableList<Class<?>> entities; private final String configPath; private final ObjectMapper MAPPER = Jackson.newObjectMapper(); private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator(); public static DropwizardHibernateRule create(String configPath, ImmutableList<Class<?>> entities) { return new DropwizardHibernateRule(configPath, entities); } private DropwizardHibernateRule(String configPath, ImmutableList<Class<?>> entities) { this.configPath = configPath; this.entities = entities; } public SessionFactory getSessionFactory() { return sessionFactory; } @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { before(); try { base.evaluate(); } finally { after(); } } }; } protected void before() throws IOException, ConfigurationException, ClassNotFoundException, URISyntaxException { ConfigurationFactory<NezhaConfiguration> factory = new ConfigurationFactory<>( NezhaConfiguration.class, validator, MAPPER, ""); nezhaConfiguration = factory.build(new File(Resources.getResource(configPath).toURI())); final MetricRegistry metricRegistry = new MetricRegistry(); DataSourceFactory dataSourceFactory = nezhaConfiguration.getDataSourceFactory(); final ManagedDataSource dataSource = dataSourceFactory.build(metricRegistry, "Rule"); final ConnectionProvider provider = buildConnectionProvider(dataSource, dataSourceFactory.getProperties()); sessionFactory = buildSessionFactory(dataSourceFactory, provider, ImmutableMap.<String, String> of(), entities); // open session/transaction ManagedSessionContext.bind(sessionFactory.openSession()); FlywayFactory flywayFactory = nezhaConfiguration.getFlywayFactory(); final String[] fwLocations = new String[flywayFactory.getLocations().size()]; Flyway flyway = new Flyway(); flyway.setDataSource(nezhaConfiguration.getDataSourceFactory().getUrl(), nezhaConfiguration.getDataSourceFactory().getUser(), nezhaConfiguration.getDataSourceFactory().getPassword()); flyway.setLocations(flywayFactory.getLocations().toArray(fwLocations)); flyway.migrate(); } protected void after() { // close session/transaction ManagedSessionContext.unbind(sessionFactory); } /** * From io.dropwizard.hibernate.SessionFactoryFactory */ private ConnectionProvider buildConnectionProvider(DataSource dataSource, Map<String, String> properties) { final DatasourceConnectionProviderImpl connectionProvider = new DatasourceConnectionProviderImpl(); connectionProvider.setDataSource(dataSource); connectionProvider.configure(properties); return connectionProvider; } /** * From com.yammer.dropwizard.hibernate.SessionFactoryFactory */ private SessionFactory buildSessionFactory(DataSourceFactory dbConfig, ConnectionProvider connectionProvider, ImmutableMap<String, String> properties, List<Class<?>> entities) { final Configuration configuration = new Configuration(); configuration.setProperty(AvailableSettings.CURRENT_SESSION_CONTEXT_CLASS, "managed"); configuration.setProperty(AvailableSettings.USE_SQL_COMMENTS, Boolean.toString(dbConfig.isAutoCommentsEnabled())); configuration.setProperty(AvailableSettings.USE_GET_GENERATED_KEYS, "true"); configuration.setProperty(AvailableSettings.GENERATE_STATISTICS, "true"); configuration.setProperty(AvailableSettings.USE_REFLECTION_OPTIMIZER, "true"); configuration.setProperty(AvailableSettings.ORDER_UPDATES, "true"); configuration.setProperty(AvailableSettings.ORDER_INSERTS, "true"); configuration.setProperty(AvailableSettings.USE_NEW_ID_GENERATOR_MAPPINGS, "true"); configuration.setProperty("jadira.usertype.autoRegisterUserTypes", "true"); for (Map.Entry<String, String> property : properties.entrySet()) { configuration.setProperty(property.getKey(), property.getValue()); } addAnnotatedClasses(configuration, entities); final ServiceRegistry registry = new StandardServiceRegistryBuilder() .addService(ConnectionProvider.class, connectionProvider) .applySettings(properties) .build(); return configuration.buildSessionFactory(registry); } /** * From com.yammer.dropwizard.hibernate.SessionFactoryFactory */ private void addAnnotatedClasses(Configuration configuration, Iterable<Class<?>> entities) { final SortedSet<String> entityClasses = Sets.newTreeSet(); for (Class<?> klass : entities) { configuration.addAnnotatedClass(klass); entityClasses.add(klass.getCanonicalName()); } } }