Java 高并发秒杀简单实现【Spring Boot + JPA + MySQL + Redis】

393人浏览 / 0人评论

参考:

java高并发秒杀活动的各种简单实现【springBoot+mybatis+redis+mysql】

依赖:

<!-- Spring Data Redis starter: provides RedisTemplate / RedisConnectionFactory used by RedisConfiguration and RedisUtil. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Jedis Redis client. NOTE(review): the data-redis starter already brings Lettuce as its default client;
     with both on the classpath, confirm which one is actually used (spring.data.redis.client-type). -->
<dependency>
    <groupId>redis.clients</groupId>
    <artifactId>jedis</artifactId>
</dependency>
<!-- JPA starter: the HighConcurrencyTest entity and HighConcurrencyTestRepository. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<!-- Spring MVC: HighConcurrencyTestController (@RestController / @GetMapping). -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- springdoc OpenAPI + Swagger UI: @Tag / @Operation annotations and SpringDocConfig (UI at /swagger-ui/index.html). -->
<dependency>
    <groupId>org.springdoc</groupId>
    <artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
    <version>2.0.3</version>
</dependency>
<!-- Bean Validation starter; not referenced by the code shown in this post. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-validation</artifactId>
</dependency>
<!-- NOTE(review): the code also needs a MySQL JDBC driver, Lombok (@Data / @NoArgsConstructor / @AllArgsConstructor)
     and fastjson (com.alibaba.fastjson.JSON in RedisUtil), none of which are listed above; presumably they are
     declared elsewhere in the pom (versions of the Boot starters presumably come from the Spring Boot parent) — verify. -->

HighConcurrencyTest.java

/**
 * JPA entity mapped to table {@code high_concurrency_test}: one row per flash-sale item.
 *
 * <p>Accessors, {@code equals}/{@code hashCode}/{@code toString} and both constructors are
 * generated by Lombok ({@code @Data}, {@code @NoArgsConstructor}, {@code @AllArgsConstructor}).
 * The all-args constructor takes the fields in declaration order: id, goodsName, goodsSum, version.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Entity
@Table(name = "high_concurrency_test")
public class HighConcurrencyTest implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Primary key, generated by the database identity column. */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Integer id;

    /** Item name; rows are looked up by this value. */
    private String goodsName;

    /** Remaining stock of the item. */
    private Integer goodsSum;

    /** Version column used by JPA optimistic locking ({@code @Version}). */
    @Version
    private Integer version;
}

HighConcurrencyTestRepository.java

import jakarta.persistence.LockModeType;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Lock;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;

/**
 * Spring Data JPA repository for {@link HighConcurrencyTest} rows.
 *
 * @Author FengHao
 * @Date 2023/3/14
 **/
@Repository
public interface HighConcurrencyTestRepository extends JpaRepository<HighConcurrencyTest, Integer> {

    /**
     * Plain derived query: loads the item with the given name, without any locking.
     */
    HighConcurrencyTest findByGoodsName(String name);

    /**
     * Overwrites the stock of row {@code id} with {@code goodsSum} using a JPQL bulk update.
     *
     * <p>The query neither checks nor bumps {@code version}, so it bypasses the entity's
     * {@code @Version} check; callers must serialize concurrent access themselves.
     *
     * @param goodsSum new stock value
     * @param id       primary key of the row to update
     * @return number of rows updated
     */
    @Transactional
    @Modifying
    @Query(value = "update HighConcurrencyTest h set h.goodsSum = ?1 where h.id = ?2")
    int updateById(Integer goodsSum, Integer id);

    /**
     * Hand-written optimistic-lock update: writes {@code goodsSum} and {@code newVersion} where the
     * version still equals {@code version}.
     *
     * <p>Caution: the WHERE clause filters on {@code version} only, not on the id, so it updates
     * EVERY row whose version currently equals {@code version}. It is only safe while the table
     * holds a single row. Prefer {@link #updateByIdAndVersion(Integer, Integer, Integer, Integer)}.
     *
     * @param goodsSum   new stock value
     * @param newVersion version to write (normally {@code version + 1})
     * @param version    version the caller read
     * @return number of rows updated; 0 means another writer changed the version first
     */
    @Transactional
    @Modifying
    @Query(value = "update HighConcurrencyTest h set h.goodsSum = ?1, h.version = ?2 where h.version = ?3")
    int updateByVersion(Integer goodsSum, Integer newVersion, Integer version);

    /**
     * Compare-and-set update of a single row: writes {@code goodsSum} and {@code newVersion} to row
     * {@code id} only if its version still equals {@code version}.
     *
     * @param goodsSum   new stock value
     * @param newVersion version to write (normally {@code version + 1})
     * @param id         primary key of the row to update
     * @param version    version the caller read
     * @return 1 if the row was updated; 0 if it does not exist or another writer changed its version
     */
    @Transactional
    @Modifying
    @Query(value = "update HighConcurrencyTest h set h.goodsSum = ?1, h.version = ?2 where h.id = ?3 and h.version = ?4")
    int updateByIdAndVersion(Integer goodsSum, Integer newVersion, Integer id, Integer version);

    /**
     * Loads the item by name; intended as the read step of the lock-based demos.
     *
     * <p>The "WithLocking" part of the name is purely descriptive. Spring Data derives the query
     * from the part after "By" and treats text between "find" and "By" as a description, so the
     * name alone does not lock anything. A lock only comes from {@code @Lock}, which is commented
     * out below. As written, this method is therefore a plain select.
     *
     * <p>Because the method itself is {@code @Transactional}, a lock enabled here would presumably
     * be released as soon as this call returns, unless the caller already runs inside a transaction.
     *
     * <p>JPA lock modes: OPTIMISTIC verifies the version at commit and increments it only if the
     * entity is modified. OPTIMISTIC_FORCE_INCREMENT also increments the version when the entity
     * was only read. PESSIMISTIC_WRITE takes a database write lock on the row.
     */
    @Transactional
    //@Lock(LockModeType.OPTIMISTIC) // optional: optimistic lock with the given mode
    //@Lock(LockModeType.PESSIMISTIC_WRITE) // optional: pessimistic (row) lock with the given mode
    HighConcurrencyTest findWithLockingByGoodsName(String goodsName);

}

HighConcurrencyTestController.java

import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.persistence.EntityManagerFactory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import jakarta.annotation.Resource;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Demo endpoints that fire many concurrent purchase attempts at one item ("Rolls-Royce") and
 * compare ways of decrementing its stock.
 *
 * <p>Reset {@code goods_sum} to 500 manually before each run.
 *
 * @Author FengHao
 * @Date 2023/3/14
 **/
@RestController
// @Tag gives the whole API group a name
@Tag(name = "java 高并发测试", description = "java 高并发测试 API")
// @ApiResponses declares a default 200 "request succeeded" response for every endpoint
@ApiResponses(@ApiResponse(responseCode = "200", description = "接口请求成功"))
@RequestMapping("/highConcurrencyTest")
public class HighConcurrencyTestController {
    /** The single item every strategy competes for. */
    private static final String GOODS_NAME = "Rolls-Royce";

    /** Appended to the item name to form the Redis key of its stock counter. */
    private static final String STOCK_KEY_SUFFIX = ":goodsSum";

    /** Purchase threads per {@link #start(Integer)} call, and decrements per {@code test} thread. */
    private static final int ATTEMPTS_PER_RUN = 500;

    @Resource
    private HighConcurrencyTestRepository repository;

    @Resource
    private EntityManagerFactory entityManagerFactory;

    /** JVM-local lock shared by all threads of strategy 5. */
    private final Lock lock = new ReentrantLock();

    @Resource
    private RedisUtil redisUtil;

    /** Strategy 1: no concurrency control at all (deliberately wrong). */
    @Operation(summary = "不做任何处理", description = "错误演示")
    @GetMapping("/noHandle")
    public void neHandle() {
        System.err.println("noHandle");
        start(1);
    }

    /** Strategy 2: hand-written version compare-and-set. */
    @Operation(summary = "使用乐观锁", description = "乐观锁")
    @GetMapping("/useOptimisticLock")
    public void useOptimisticLock() {
        System.err.println("useOptimisticLock");
        start(2);
    }

    /** Strategy 3: load via findWithLockingByGoodsName and save the decremented entity. */
    @Operation(summary = "使用悲观锁", description = "数据库层面悲观写锁")
    @GetMapping("/usePessimisticLock")
    public void usePessimisticLock() {
        System.err.println("usePessimisticLock");
        start(3);
    }

    /** Strategy 4: read-modify-write inside {@code synchronized (this)}. */
    @Operation(summary = "使用 java 线程同步锁", description = "use_synchronized")
    @GetMapping("/useSynchronized")
    public void useSynchronized() {
        System.err.println("useSynchronized");
        start(4);
    }

    /** Strategy 5: read-modify-write while holding {@link #lock}. */
    @Operation(summary = "使用 java 线程可重入锁", description = "use_reentrantLock")
    @GetMapping("/useReentrantLock")
    public void useReentrantLock() {
        System.err.println("useReentrantLock");
        start(5);
    }

    /**
     * Strategy 6: atomic decrement of a Redis counter.
     *
     * <p>The counter is seeded from the database stock when it is missing or exhausted. Then two
     * rounds run, and each round consists of {@code start(6)} plus one {@code test} thread.
     */
    @Operation(summary = "使用 redis", description = "利用 redis 单线程的特性")
    @GetMapping("/useRedis")
    public void useRedis() {
        System.err.println("useRedis");
        HighConcurrencyTest highConcurrencyTest = repository.findByGoodsName(GOODS_NAME);
        if (highConcurrencyTest != null && highConcurrencyTest.getGoodsSum() > 0) {
            String key = stockKey(highConcurrencyTest);
            Object cached = redisUtil.get(key);
            // Every decrement lands even past zero, so a finished run leaves the counter
            // negative. Re-seed on "<= 0", not just "== 0", or later runs would never re-seed.
            if (cached == null || ((Number) cached).intValue() <= 0) {
                redisUtil.setNumber(key, highConcurrencyTest.getGoodsSum());
            }
            start(6);
            test(highConcurrencyTest);
            start(6);
            test(highConcurrencyTest);
        }
    }

    /**
     * Starts one extra thread that decrements the item's Redis counter {@value #ATTEMPTS_PER_RUN}
     * times without checking the result, competing with the purchase threads.
     */
    private void test(HighConcurrencyTest highConcurrencyTest) {
        String key = stockKey(highConcurrencyTest);
        new Thread(() -> {
            for (int i = 0; i < ATTEMPTS_PER_RUN; i++) {
                Long increment = redisUtil.increment(key, -1);
                System.err.println("increment: " + increment);
            }
        }).start();
    }

    /**
     * Starts {@value #ATTEMPTS_PER_RUN} threads. Each makes one purchase attempt with the given
     * strategy and prints "success" or "fail" to stderr.
     *
     * @param type strategy: 1 none, 2 manual version CAS, 3 save with @Version check,
     *             4 synchronized, 5 ReentrantLock, 6 Redis counter; any other value reports "fail"
     */
    public void start(Integer type) {
        for (int i = 0; i < ATTEMPTS_PER_RUN; i++) {
            Thread thread = new Thread(() -> System.err.println(attempt(type) ? "success" : "fail"));
            thread.start();
        }
    }

    /** Dispatches one purchase attempt to the given strategy; true when one unit was taken. */
    private boolean attempt(Integer type) {
        switch (type) {
            case 1:
                return decrementWithoutLocking();
            case 2:
                return decrementWithVersionCas();
            case 3:
                return decrementWithEntityVersionCheck();
            case 4:
                return decrementSynchronized();
            case 5:
                return decrementWithReentrantLock();
            case 6:
                return decrementRedisCounter();
            default:
                return false;
        }
    }

    /**
     * Strategy 1: unguarded read-modify-write. Concurrent threads can read the same stock and
     * overwrite each other's update (lost updates), which is what this demo shows.
     */
    private boolean decrementWithoutLocking() {
        HighConcurrencyTest item = repository.findByGoodsName(GOODS_NAME);
        if (item == null || item.getGoodsSum() <= 0) {
            return false;
        }
        item.setGoodsSum(item.getGoodsSum() - 1);
        repository.save(item);
        return true;
    }

    /**
     * Strategy 2: optimistic locking done by hand. Read the version, then update only if it is
     * unchanged. A lost race updates 0 rows and is reported as a failure; there is no retry.
     *
     * <p>NOTE(review): the repository's updateByVersion query filters on version only (no id),
     * which is only safe while the table holds a single row.
     */
    private boolean decrementWithVersionCas() {
        HighConcurrencyTest item = repository.findByGoodsName(GOODS_NAME);
        if (item == null || item.getGoodsSum() <= 0) {
            return false;
        }
        int version = item.getVersion();
        return repository.updateByVersion(item.getGoodsSum() - 1, version + 1, version) == 1;
    }

    /**
     * Strategy 3, labelled "pessimistic lock" in the API docs.
     *
     * <p>NOTE(review): in the repository as shown, the {@code @Lock} annotation on
     * findWithLockingByGoodsName is commented out. The lookup also runs in its own transaction,
     * so no row lock is held across read-modify-write. Concurrency control here effectively comes
     * from the entity's {@code @Version}: save() of a stale copy is expected to throw an
     * optimistic-locking exception, which counts as a failed attempt.
     *
     * <p>Do not switch to updateById here: its query ignores the version.
     */
    private boolean decrementWithEntityVersionCheck() {
        HighConcurrencyTest item = repository.findWithLockingByGoodsName(GOODS_NAME);
        if (item == null || item.getGoodsSum() <= 0) {
            return false;
        }
        item.setGoodsSum(item.getGoodsSum() - 1);
        try {
            repository.save(item);
            return true;
        } catch (Exception versionConflict) {
            // Presumably a stale-version conflict under contention; reported as "fail".
            return false;
        }
    }

    /**
     * Strategy 4: serialize the read-modify-write with the controller's monitor. This only
     * coordinates threads inside this JVM, not multiple application instances.
     */
    private boolean decrementSynchronized() {
        synchronized (this) {
            return decrementById();
        }
    }

    /**
     * Strategy 5: same as strategy 4 but with {@link #lock}. The lock is released in
     * {@code finally} so that a repository exception cannot leave it held forever.
     */
    private boolean decrementWithReentrantLock() {
        lock.lock();
        try {
            return decrementById();
        } finally {
            lock.unlock();
        }
    }

    /** Shared body of strategies 4 and 5: read stock and write stock - 1. Caller must hold a lock. */
    private boolean decrementById() {
        HighConcurrencyTest item = repository.findByGoodsName(GOODS_NAME);
        if (item == null || item.getGoodsSum() <= 0) {
            return false;
        }
        return repository.updateById(item.getGoodsSum() - 1, item.getId()) == 1;
    }

    /**
     * Strategy 6: atomically decrement the Redis counter. The attempt succeeds when the
     * post-decrement value is still non-negative. The database stock is not touched.
     */
    private boolean decrementRedisCounter() {
        HighConcurrencyTest item = repository.findByGoodsName(GOODS_NAME);
        if (item == null) {
            // No such item: there is no counter key to decrement.
            return false;
        }
        Long increment = redisUtil.increment(stockKey(item), -1);
        System.err.println("increment: " + increment);
        return increment != null && increment >= 0;
    }

    /** Redis key of the stock counter for the given item. */
    private static String stockKey(HighConcurrencyTest item) {
        return item.getGoodsName() + STOCK_KEY_SUFFIX;
    }

}

RedisConfiguration.java

/**
 * Copyright 2018 人人开源 http://www.renren.io
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.jbritian.springdatajpa.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.*;
import org.springframework.data.redis.serializer.*;

/**
 * Redis wiring: defines the {@code redisTemplate} bean and exposes each of its operations views
 * as a separate bean.
 */
@Configuration
public class RedisConfiguration {
    @Autowired
    private RedisConnectionFactory redisConnectionFactory;

    /**
     * Template with plain-string keys, hash keys and hash values, and Jackson JSON values
     * ({@link GenericJackson2JsonRedisSerializer}).
     *
     * <p>NOTE(review): hash values use a String serializer although the template's value type is
     * {@code Object}; presumably only String hash values are ever stored — confirm.
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate() {
        StringRedisSerializer plainText = new StringRedisSerializer();

        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);
        template.setKeySerializer(plainText);
        template.setHashKeySerializer(plainText);
        template.setHashValueSerializer(plainText);
        template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        return template;
    }

    /** Hash view of {@code redisTemplate}. */
    @Bean
    public HashOperations<String, String, Object> hashOperations(RedisTemplate<String, Object> redisTemplate) {
        HashOperations<String, String, Object> hashView = redisTemplate.opsForHash();
        return hashView;
    }

    /** String/value view of {@code redisTemplate}. */
    @Bean
    public ValueOperations<String, Object> valueOperations(RedisTemplate<String, Object> redisTemplate) {
        ValueOperations<String, Object> valueView = redisTemplate.opsForValue();
        return valueView;
    }

    /** List view of {@code redisTemplate}. */
    @Bean
    public ListOperations<String, Object> listOperations(RedisTemplate<String, Object> redisTemplate) {
        ListOperations<String, Object> listView = redisTemplate.opsForList();
        return listView;
    }

    /** Set view of {@code redisTemplate}. */
    @Bean
    public SetOperations<String, Object> setOperations(RedisTemplate<String, Object> redisTemplate) {
        SetOperations<String, Object> setView = redisTemplate.opsForSet();
        return setView;
    }

    /** Sorted-set view of {@code redisTemplate}. */
    @Bean
    public ZSetOperations<String, Object> zSetOperations(RedisTemplate<String, Object> redisTemplate) {
        ZSetOperations<String, Object> sortedSetView = redisTemplate.opsForZSet();
        return sortedSetView;
    }

}

SpringDocConfig.java

import io.swagger.v3.oas.models.Components;
import io.swagger.v3.oas.models.OpenAPI;
import io.swagger.v3.oas.models.info.Info;
import io.swagger.v3.oas.models.info.License;
import io.swagger.v3.oas.models.security.SecurityScheme;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * OpenAPI metadata for springdoc; the docs UI is at /swagger-ui/index.html.
 *
 * @Author FengHao
 * @Date 2023/3/15
 **/
@Configuration
public class SpringDocConfig {
    /**
     * Builds the OpenAPI description. It registers an HTTP bearer (JWT) security scheme named
     * "authScheme", i.e. header {@code Authorization: Bearer xxx.xxx.xxx}, and sets the title,
     * version, description and Apache 2.0 license.
     */
    @Bean
    public OpenAPI customOpenApi() {
        SecurityScheme bearerJwt = new SecurityScheme()
                .type(SecurityScheme.Type.HTTP)
                .scheme("bearer")
                .bearerFormat("JWT");
        Components components = new Components().addSecuritySchemes("authScheme", bearerJwt);

        License apacheLicense = new License()
                .name("Apache 2.0")
                .url("https://www.apache.org/licenses/LICENSE-2.0");
        Info apiInfo = new Info()
                .title("spring-data-jpa 的 API")
                .version("1.0.0")
                .description("加油↖(^ω^)↗")
                .license(apacheLicense);

        return new OpenAPI().components(components).info(apiInfo);
    }
}

RedisUtil.java

import com.alibaba.fastjson.JSON;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.*;
import org.springframework.stereotype.Component;

import jakarta.annotation.Resource;
import java.util.concurrent.TimeUnit;

/**
 * Convenience wrapper over the {@code redisTemplate} bean.
 *
 * <p>There are two write styles. {@link #setNumber} stores the value object as-is and lets the
 * template's value serializer encode it; presumably this is so that {@link #increment} can later
 * operate on the number. {@link #set} first converts the value to a String with fastjson.
 *
 * @Author 风仔
 * @Date 2022/12/29
 * @Version 1.0
 **/
@Component
public class RedisUtil {
    /**
     * Default TTL in seconds (one day).
     */
    public static final long DEFAULT_EXPIRE = 60 * 60 * 24;
    /**
     * Sentinel value for {@code expire} parameters: skip the EXPIRE call.
     */
    public static final long NOT_EXPIRE = -1;

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;
    // The operations views below are all injected from the bean named "redisTemplate".
    @Resource(name = "redisTemplate")
    private ValueOperations<String, Object> valueOperations;
    @Resource(name = "redisTemplate")
    private HashOperations<String, String, Object> hashOperations;
    @Resource(name = "redisTemplate")
    private ListOperations<String, Object> listOperations;
    @Resource(name = "redisTemplate")
    private SetOperations<String, Object> setOperations;
    @Resource(name = "redisTemplate")
    private ZSetOperations<String, Object> zSetOperations;

    /**
     * Stores {@code value} without JSON conversion and, unless {@code expire} is
     * {@link #NOT_EXPIRE}, sets its TTL.
     *
     * @param expire TTL in seconds, or {@link #NOT_EXPIRE}
     */
    public void setNumber(String key, Object value, long expire) {
        valueOperations.set(key, value);
        applyExpire(key, expire);
    }

    /** {@link #setNumber(String, Object, long)} with {@link #DEFAULT_EXPIRE}. */
    public void setNumber(String key, Object value) {
        setNumber(key, value, DEFAULT_EXPIRE);
    }

    /**
     * Stores {@code value} as a String: simple types via {@code String.valueOf}, everything else
     * as fastjson JSON. Unless {@code expire} is {@link #NOT_EXPIRE}, it also sets the TTL.
     *
     * @param expire TTL in seconds, or {@link #NOT_EXPIRE}
     */
    public void set(String key, Object value, long expire) {
        valueOperations.set(key, toJson(value));
        applyExpire(key, expire);
    }

    /** {@link #set(String, Object, long)} with {@link #DEFAULT_EXPIRE}. */
    public void set(String key, Object value) {
        set(key, value, DEFAULT_EXPIRE);
    }

    /**
     * Reads {@code key} and parses it into {@code clazz} with fastjson. Unless {@code expire} is
     * {@link #NOT_EXPIRE}, the key's TTL is reset to {@code expire} seconds after the read.
     *
     * @return the parsed value, or {@code null} if the key does not exist
     */
    public <T> T get(String key, Class<T> clazz, long expire) {
        Object value = get(key, expire);
        // toJson() returns Strings unchanged, so values written by set() parse exactly as before.
        // Non-String values (e.g. a number written by setNumber()) are converted to JSON first
        // instead of failing with a ClassCastException.
        return value == null ? null : fromJson(toJson(value), clazz);
    }

    /** {@link #get(String, Class, long)} without touching the TTL. */
    public <T> T get(String key, Class<T> clazz) {
        return get(key, clazz, NOT_EXPIRE);
    }

    /**
     * Reads {@code key} as decoded by the template's value serializer. Unless {@code expire} is
     * {@link #NOT_EXPIRE}, the key's TTL is reset to {@code expire} seconds after the read.
     *
     * @return the stored value, or {@code null} if the key does not exist
     */
    public Object get(String key, long expire) {
        Object value = valueOperations.get(key);
        applyExpire(key, expire);
        return value;
    }

    /** {@link #get(String, long)} without touching the TTL. */
    public Object get(String key) {
        return get(key, NOT_EXPIRE);
    }

    /** Deletes {@code key}. */
    public void delete(String key) {
        redisTemplate.delete(key);
    }

    /**
     * Atomically adds {@code num} (may be negative) to the integer stored at {@code key}.
     *
     * @return the value after the increment; per Spring Data Redis, {@code null} when called
     *         inside a pipeline or transaction
     */
    public Long increment(String key, int num) {
        return valueOperations.increment(key, num);
    }

    /** Sets the TTL of {@code key} to {@code expire} seconds unless it is {@link #NOT_EXPIRE}. */
    private void applyExpire(String key, long expire) {
        if (expire != NOT_EXPIRE) {
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
    }

    /**
     * Converts an object to its stored String form: boxed primitives and Strings via
     * {@code String.valueOf}, anything else as fastjson JSON.
     */
    private String toJson(Object object) {
        if (object instanceof Integer || object instanceof Long || object instanceof Float ||
                object instanceof Double || object instanceof Boolean || object instanceof String) {
            return String.valueOf(object);
        }
        return JSON.toJSONString(object);
    }

    /**
     * Parses a JSON string into an instance of {@code clazz} with fastjson.
     */
    private <T> T fromJson(String json, Class<T> clazz) {
        return JSON.parseObject(json, clazz);
    }
}

全部评论