Skip to content

Instantly share code, notes, and snippets.

@vrajat
Created April 15, 2015 05:58
Show Gist options
  • Select an option

  • Save vrajat/3314181737ecf28ef443 to your computer and use it in GitHub Desktop.

Select an option

Save vrajat/3314181737ecf28ef443 to your computer and use it in GitHub Desktop.

Revisions

  1. vrajat created this gist Apr 15, 2015.
    170 changes: 170 additions & 0 deletions DropwizardHibernateRule.java
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,170 @@
    package com.qubole.nezha;

    import java.io.File;
    import java.io.IOException;
    import java.net.URISyntaxException;
    import java.util.List;
    import java.util.Map;
    import java.util.SortedSet;

    import javax.sql.DataSource;
    import javax.validation.Validation;
    import javax.validation.Validator;

    import com.codahale.metrics.MetricRegistry;
    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.google.common.io.Resources;
    import com.qubole.nezha.NezhaConfiguration;
    import io.dropwizard.configuration.ConfigurationFactory;
    import io.dropwizard.db.DataSourceFactory;
    import io.dropwizard.flyway.FlywayFactory;
    import io.dropwizard.jackson.Jackson;
    import org.flywaydb.core.Flyway;
    import org.hibernate.SessionFactory;
    import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
    import org.hibernate.cfg.AvailableSettings;
    import org.hibernate.cfg.Configuration;
    import org.hibernate.context.internal.ManagedSessionContext;
    import org.hibernate.service.ServiceRegistry;
    import org.hibernate.engine.jdbc.connections.internal.DatasourceConnectionProviderImpl;
    import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
    import org.junit.rules.TestRule;
    import org.junit.runner.Description;
    import org.junit.runners.model.Statement;

    import com.google.common.collect.ImmutableList;
    import com.google.common.collect.ImmutableMap;
    import com.google.common.collect.Sets;
    import io.dropwizard.configuration.ConfigurationException;
    import io.dropwizard.db.ManagedDataSource;

    /*
    Inspired by https://gist.github.com/Sch3lp/9185192
    */
    public class DropwizardHibernateRule implements TestRule {

    private SessionFactory sessionFactory;
    private NezhaConfiguration nezhaConfiguration;
    private final ImmutableList<Class<?>> entities;
    private final String configPath;

    private final ObjectMapper MAPPER = Jackson.newObjectMapper();
    private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator();

    public static DropwizardHibernateRule create(String configPath,
    ImmutableList<Class<?>> entities) {
    return new DropwizardHibernateRule(configPath, entities);
    }

    private DropwizardHibernateRule(String configPath, ImmutableList<Class<?>> entities) {
    this.configPath = configPath;
    this.entities = entities;
    }

    public SessionFactory getSessionFactory() {
    return sessionFactory;
    }

    @Override
    public Statement apply(final Statement base, Description description) {
    return new Statement() {

    @Override
    public void evaluate() throws Throwable {
    before();
    try {
    base.evaluate();
    } finally {
    after();
    }
    }
    };
    }

    protected void before() throws
    IOException, ConfigurationException, ClassNotFoundException, URISyntaxException {
    ConfigurationFactory<NezhaConfiguration> factory = new ConfigurationFactory<>(
    NezhaConfiguration.class, validator, MAPPER, "");
    nezhaConfiguration = factory.build(new File(Resources.getResource(configPath).toURI()));

    final MetricRegistry metricRegistry = new MetricRegistry();
    DataSourceFactory dataSourceFactory = nezhaConfiguration.getDataSourceFactory();
    final ManagedDataSource dataSource = dataSourceFactory.build(metricRegistry, "Rule");
    final ConnectionProvider provider = buildConnectionProvider(dataSource,
    dataSourceFactory.getProperties());
    sessionFactory = buildSessionFactory(dataSourceFactory, provider, ImmutableMap.<String, String> of(), entities);

    // open session/transaction
    ManagedSessionContext.bind(sessionFactory.openSession());

    FlywayFactory flywayFactory = nezhaConfiguration.getFlywayFactory();
    final String[] fwLocations = new String[flywayFactory.getLocations().size()];

    Flyway flyway = new Flyway();
    flyway.setDataSource(nezhaConfiguration.getDataSourceFactory().getUrl(),
    nezhaConfiguration.getDataSourceFactory().getUser(),
    nezhaConfiguration.getDataSourceFactory().getPassword());
    flyway.setLocations(flywayFactory.getLocations().toArray(fwLocations));
    flyway.migrate();
    }

    protected void after() {
    // close session/transaction
    ManagedSessionContext.unbind(sessionFactory);
    }

    /**
    * From io.dropwizard.hibernate.SessionFactoryFactory
    */
    private ConnectionProvider buildConnectionProvider(DataSource dataSource,
    Map<String, String> properties) {
    final DatasourceConnectionProviderImpl connectionProvider = new DatasourceConnectionProviderImpl();
    connectionProvider.setDataSource(dataSource);
    connectionProvider.configure(properties);
    return connectionProvider;
    }

    /**
    * From com.yammer.dropwizard.hibernate.SessionFactoryFactory
    */
    private SessionFactory buildSessionFactory(DataSourceFactory dbConfig,
    ConnectionProvider connectionProvider,
    ImmutableMap<String, String> properties,
    List<Class<?>> entities) {
    final Configuration configuration = new Configuration();
    configuration.setProperty(AvailableSettings.CURRENT_SESSION_CONTEXT_CLASS, "managed");
    configuration.setProperty(AvailableSettings.USE_SQL_COMMENTS, Boolean.toString(dbConfig.isAutoCommentsEnabled()));
    configuration.setProperty(AvailableSettings.USE_GET_GENERATED_KEYS, "true");
    configuration.setProperty(AvailableSettings.GENERATE_STATISTICS, "true");
    configuration.setProperty(AvailableSettings.USE_REFLECTION_OPTIMIZER, "true");
    configuration.setProperty(AvailableSettings.ORDER_UPDATES, "true");
    configuration.setProperty(AvailableSettings.ORDER_INSERTS, "true");
    configuration.setProperty(AvailableSettings.USE_NEW_ID_GENERATOR_MAPPINGS, "true");
    configuration.setProperty("jadira.usertype.autoRegisterUserTypes", "true");
    for (Map.Entry<String, String> property : properties.entrySet()) {
    configuration.setProperty(property.getKey(), property.getValue());
    }

    addAnnotatedClasses(configuration, entities);

    final ServiceRegistry registry = new StandardServiceRegistryBuilder()
    .addService(ConnectionProvider.class, connectionProvider)
    .applySettings(properties)
    .build();

    return configuration.buildSessionFactory(registry);
    }

    /**
    * From com.yammer.dropwizard.hibernate.SessionFactoryFactory
    */
    private void addAnnotatedClasses(Configuration configuration,
    Iterable<Class<?>> entities) {
    final SortedSet<String> entityClasses = Sets.newTreeSet();
    for (Class<?> klass : entities) {
    configuration.addAnnotatedClass(klass);
    entityClasses.add(klass.getCanonicalName());
    }
    }

    }